# DCMNet / dcm_app.py
# EricBoi's picture
# Refactor app.py to integrate DCM functionality and streamline molecule visualization; update requirements.txt to include plotnine and correct dcmnet repository URL.
# 88cd151
# raw
# history blame
# 5.68 kB
import ase
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from scipy.spatial.distance import cdist
from dcmnet.data import prepare_batches
from dcmnet.modules import MessagePassingModel
from dcmnet.loss import (
esp_loss_eval,
esp_loss_pots,
esp_mono_loss_pots,
get_predictions,
)
from dcmnet.plotting import evaluate_dc, create_plots2
from dcmnet.multimodel import get_atoms_dcmol
from dcmnet.multipoles import plot_3d
from dcmnet.utils import apply_model, clip_colors, reshape_dipole
# Seed for the JAX PRNG key created below.
RANDOM_NUMBER = 0
# Split the seeded key into two sub-keys and keep only the first one; it is
# the key handed to prepare_batches() in run_dcm().
data_key, _ = jax.random.split(jax.random.PRNGKey(RANDOM_NUMBER), 2)
# Model hyperparameters.
# These are forwarded unchanged to both MessagePassingModel instances built in
# create_models().
features = 16
max_degree = 2
num_iterations = 2
num_basis_functions = 8
cutoff = 4.0
def create_models():
    """
    build the two MessagePassingModel instances used by the app.

    Both models are configured from the module-level hyperparameters and
    differ only in their ``n_dcm`` argument (1 for the first, 2 for the second).

    :return: tuple ``(dcm1, dcm2)``
    """
    # Settings shared by both models; read from the module globals at call time.
    shared_kwargs = dict(
        features=features,
        max_degree=max_degree,
        num_iterations=num_iterations,
        num_basis_functions=num_basis_functions,
        cutoff=cutoff,
    )
    dcm1, dcm2 = (
        MessagePassingModel(**shared_kwargs, n_dcm=n_dcm) for n_dcm in (1, 2)
    )
    return dcm1, dcm2
def get_grid_points(coordinates):
    """
    create a uniform 15 x 15 x 15 grid of points around the molecule.

    The grid spans the axis-aligned bounding box of the coordinates, enlarged
    by 3.0 on every side. Grid points closer than 2.4 (2.5 - 0.1) to any of
    the given coordinates are dropped.

    :param coordinates: array-like of shape (n_atoms, 3) with atom positions
    :return: array of shape (n_points, 3) with the remaining grid points
    """
    padding = 3.0
    min_distance = 2.5 - 1e-1

    # Padded bounding box of the molecule.
    lower = np.min(coordinates, axis=0) - padding
    upper = np.max(coordinates, axis=0) + padding

    # One evenly spaced axis per Cartesian dimension, combined into a full grid.
    axes = [np.linspace(lo, hi, 15) for lo, hi in zip(lower, upper)]
    mesh = np.stack(np.meshgrid(*axes), axis=0)
    candidates = np.reshape(mesh.T, [-1, 3])

    # exclude points that are too close to the molecule
    far_enough = np.all(cdist(candidates, coordinates) >= min_distance, axis=-1)
    return candidates[far_enough]
def load_weights():
    """
    load the pickled parameter sets for the two models from the ``wbs`` directory.

    The paths are relative, so they are resolved against the current working
    directory. The files are unpickled via ``pd.read_pickle``; only load
    weight files from a trusted source.

    :return: tuple ``(dcm1_weights, dcm2_weights)``
    """
    weight_files = (
        "wbs/best_0.0_params.pkl",
        "wbs/dcm2-best_1000.0_params.pkl",
    )
    # Files are read lazily in order, first the dcm1 weights, then the dcm2 weights.
    dcm1_weights, dcm2_weights = (pd.read_pickle(path) for path in weight_files)
    return dcm1_weights, dcm2_weights
def prepare_inputs(smiles):
    """
    turn a SMILES string into a zero-padded, single-molecule input batch.

    The SMILES is parsed with RDKit, explicit hydrogens are added, a 3D
    conformer is embedded and a grid of points around the molecule is generated
    with ``get_grid_points``. Atom arrays are padded to 60 entries and the grid
    to 3143 points; every array in the batch has a leading dimension of 1.

    :param smiles: SMILES string of the molecule
    :return: tuple ``(data_batch, smiles_image)``; ``data_batch`` is a dict of
        JAX arrays, ``smiles_image`` is the 2D drawing of the hydrogen-free
        molecule with each atom annotated by its index
    :raises ValueError: if the SMILES cannot be parsed, contains no atoms or
        contains atoms without an element (atomic number 0), if no 3D conformer
        can be embedded, or if the molecule / grid exceeds the padding sizes
    """
    max_n_atoms = 60
    max_grid_points = 3143

    smiles_mol = Chem.MolFromSmiles(smiles)
    if smiles_mol is None:
        # RDKit signals a parse failure by returning None instead of raising.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")
    rdkit_mol = Chem.AddHs(smiles_mol)

    # Atomic numbers come straight from RDKit; no symbol -> number lookup needed.
    atomic_numbers = np.array(
        [atom.GetAtomicNum() for atom in rdkit_mol.GetAtoms()], dtype=int
    )
    if len(atomic_numbers) == 0:
        raise ValueError(f"SMILES string contains no atoms: {smiles!r}")
    if np.any(atomic_numbers == 0):
        # Z == 0 is reserved for padding below, so such atoms cannot be represented.
        raise ValueError(
            f"SMILES string contains atoms without an element: {smiles!r}"
        )
    if len(atomic_numbers) > max_n_atoms:
        raise ValueError(
            f"Molecule has {len(atomic_numbers)} atoms (including hydrogens); "
            f"at most {max_n_atoms} are supported"
        )

    # EmbedMolecule returns the new conformer id, or -1 when embedding fails.
    if AllChem.EmbedMolecule(rdkit_mol) < 0:
        raise ValueError(f"Could not embed a 3D conformer for SMILES: {smiles!r}")
    coordinates = rdkit_mol.GetConformer(0).GetPositions()

    vdw_surface = get_grid_points(coordinates)
    if len(vdw_surface) > max_grid_points:
        raise ValueError(
            f"Grid has {len(vdw_surface)} points; "
            f"at most {max_grid_points} are supported"
        )

    # Label the atoms of the hydrogen-free molecule with their index for the drawing.
    for atom in smiles_mol.GetAtoms():
        atom.SetProp("atomNote", str(atom.GetIdx()))
    smiles_image = Draw.MolToImage(smiles_mol)

    # Pad atoms with Z = 0 and positions with zeros up to max_n_atoms.
    pad_z = np.array([np.pad(atomic_numbers, (0, max_n_atoms - len(atomic_numbers)))])
    pad_coords = np.array(
        [np.pad(coordinates, ((0, max_n_atoms - len(coordinates)), (0, 0)))]
    )
    # Rows are only appended after the real grid points, so the padding rows
    # take the "after" constant (10000), i.e. they lie far away from the molecule.
    padded_surface = np.pad(
        vdw_surface,
        ((0, max_grid_points - len(vdw_surface)), (0, 0)),
        "constant",
        constant_values=(0, 10000),
    )
    pad_vdw_surface = np.array([padded_surface])

    n_atoms = np.sum(pad_z != 0)
    data_batch = dict(
        atomic_numbers=jnp.asarray(pad_z),
        Z=jnp.asarray(pad_z),
        positions=jnp.asarray(pad_coords),
        R=jnp.asarray(pad_coords),
        # N is the number of atoms
        N=jnp.asarray([n_atoms]),
        # NOTE(review): "mono" is filled with the atomic numbers and "esp" with
        # zeros; these look like placeholders for inference - confirm.
        mono=jnp.asarray(pad_z),
        ngrid=jnp.array([len(vdw_surface)]),
        n_grid=jnp.array([len(vdw_surface)]),
        esp=jnp.asarray([np.zeros(max_grid_points)]),
        vdw_surface=jnp.asarray(pad_vdw_surface),
        # The mask is True for every entry, including the padded grid points.
        espMask=jnp.asarray([np.ones(max_grid_points)], dtype=jnp.bool_),
    )
    return data_batch, smiles_image
def do_eval(batch, dipo_dc1, mono_dc1, batch_size):
    """
    evaluate one model's predictions and collect everything needed for plotting.

    Calls ``evaluate_dc`` (with plotting disabled) and then ``create_plots2``
    and merges their outputs with the given predictions into a single dict.

    :param batch: input batch the predictions were computed for
    :param dipo_dc1: second output of ``apply_model`` for this batch
    :param mono_dc1: first output of ``apply_model`` for this batch
    :param batch_size: batch size forwarded to the dcmnet helpers
    :return: dict with keys ``mono``, ``dipo``, ``esp_errors``, ``atoms``,
        ``dcmol``, ``grid``, ``esp``, ``esp_dc_pred``, ``esp_mono_pred``, ``idx_cut``
    """
    # NOTE(review): the literal 1 is passed to both helpers regardless of which
    # model produced the predictions (run_dcm uses this for dcm1 and dcm2). If
    # this argument is the number of charges per atom, the dcm2 call may need 2
    # - confirm against dcmnet.plotting.
    evaluation = evaluate_dc(
        batch, dipo_dc1, mono_dc1, batch_size, 1, plot=False
    )
    esp_errors, mono_pred, _, _ = evaluation

    plot_data = create_plots2(mono_dc1, dipo_dc1, batch, batch_size, 1)
    atoms, dcmol, grid, esp, esp_dc_pred, idx_cut = plot_data

    return dict(
        mono=mono_dc1,
        dipo=dipo_dc1,
        esp_errors=esp_errors,
        atoms=atoms,
        dcmol=dcmol,
        grid=grid,
        esp=esp,
        esp_dc_pred=esp_dc_pred,
        esp_mono_pred=mono_pred,
        idx_cut=idx_cut,
    )
def run_dcm(smiles="C1NCCCC1"):
    """
    run both models on a single molecule given as a SMILES string.

    Models and weights are created / loaded on every call. The molecule is
    converted to a padded batch, both models are applied to it, and each
    prediction is post-processed with ``do_eval``.

    :param smiles: SMILES string of the molecule
    :return: dict with ``smiles_image`` (2D drawing), ``atoms`` and ``dcmol``
        (from the first model's evaluation) and ``dcmol2`` (from the second
        model's evaluation)
    """
    batch_size = 1
    dcm1, dcm2 = create_models()
    dcm1_weights, dcm2_weights = load_weights()
    data_batch, smiles_image = prepare_inputs(smiles)

    # Only one molecule is prepared, so the first batch is the only one used.
    batch = prepare_batches(data_key, data_batch, batch_size)[0]

    # Apply both models first, then evaluate both predictions.
    predictions = []
    for model, params in ((dcm1, dcm1_weights), (dcm2, dcm2_weights)):
        mono, dipo = apply_model(model, params, batch, batch_size)
        predictions.append((mono, dipo))
    dcm1_results, dcm2_results = (
        do_eval(batch, dipo, mono, batch_size) for mono, dipo in predictions
    )

    return {
        "smiles_image": smiles_image,
        "atoms": dcm1_results["atoms"],
        "dcmol": dcm1_results["dcmol"],
        "dcmol2": dcm2_results["dcmol"],
    }
if __name__ == "__main__":
    # Script entry point: run the pipeline once on an example molecule and
    # print the resulting dict.
    smiles = "C1NCCCC1"
    results = run_dcm(smiles=smiles)
    print(results)