Enhance DCM-Net functionality in app.py by integrating support for multiple DCM layers, updating the rendering of 3D models, and improving the user interface with sliders and buttons. Extend dcm_app.py to create and load models for DCM-3 and DCM-4, and adjust weight loading to accommodate additional models. Update requirements.txt to include new dependencies for molecule visualization.
Browse files- __pycache__/dcm_app.cpython-310.pyc +0 -0
- app.py +126 -12
- dcm_app.py +49 -20
- requirements.txt +4 -0
__pycache__/dcm_app.cpython-310.pyc
CHANGED
|
Binary files a/__pycache__/dcm_app.cpython-310.pyc and b/__pycache__/dcm_app.cpython-310.pyc differ
|
|
|
app.py
CHANGED
|
@@ -1,5 +1,10 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
try:
|
| 4 |
from StringIO import StringIO
|
| 5 |
except ImportError:
|
|
@@ -7,20 +12,129 @@ except ImportError:
|
|
| 7 |
|
| 8 |
import streamlit.components.v1 as components
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
from dcm_app import run_dcm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
st.
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
from streamlit_ketcher import st_ketcher
|
| 3 |
+
from stmol import showmol
|
| 4 |
+
import py3Dmol
|
| 5 |
+
import io
|
| 6 |
+
import ase.io as ase_io
|
| 7 |
+
import numpy as np
|
| 8 |
try:
|
| 9 |
from StringIO import StringIO
|
| 10 |
except ImportError:
|
|
|
|
| 12 |
|
| 13 |
import streamlit.components.v1 as components
|
| 14 |
|
| 15 |
+
|
| 16 |
+
import streamlit as st
|
| 17 |
+
from streamlit_ketcher import st_ketcher
|
| 18 |
+
|
| 19 |
+
st.title("DCM-Net")
|
| 20 |
+
smiles = "OC(=O)[C@@H]1CCCN1"
|
| 21 |
+
# Create a molecule editor
|
| 22 |
+
smiles = st_ketcher(smiles)
|
| 23 |
+
|
| 24 |
+
# Output the SMILES string of the drawn molecule
|
| 25 |
+
st.write(f"SMILES: {smiles}")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# color charges by magnitude
|
| 29 |
+
def color_charges(mono, nDCM, atoms):
|
| 30 |
+
from matplotlib import cm
|
| 31 |
+
from matplotlib.colors import Normalize
|
| 32 |
+
|
| 33 |
+
norm = Normalize(vmin=-1, vmax=1)
|
| 34 |
+
cmap = cm.get_cmap("bwr")
|
| 35 |
+
mappable = cm.ScalarMappable(norm=norm, cmap=cmap)
|
| 36 |
+
elem = atoms.numbers
|
| 37 |
+
pccolors = mappable.to_rgba(mono.flatten()[: len(atoms.numbers)])
|
| 38 |
+
from ase.data.colors import jmol_colors
|
| 39 |
+
|
| 40 |
+
atomcolors = [jmol_colors[_] for _ in elem]
|
| 41 |
+
atomcolors_ = []
|
| 42 |
+
for _ in atomcolors:
|
| 43 |
+
atomcolors_.append(np.append(_, 0.015))
|
| 44 |
+
dcmcolors = mappable.to_rgba(mono.flatten()[: len(elem) * nDCM])
|
| 45 |
+
return dcmcolors
|
| 46 |
+
|
| 47 |
+
|
| 48 |
from dcm_app import run_dcm
|
| 49 |
+
# Create a 3D view
|
| 50 |
+
def render_3d(atoms, dcm_charges=None, dcmcolors=None):
|
| 51 |
+
view = py3Dmol.view()
|
| 52 |
+
|
| 53 |
+
view.startjs += '''\n
|
| 54 |
+
let customColorize = function(atom){
|
| 55 |
+
// attribute elem is from https://3dmol.csb.pitt.edu/doc/AtomSpec.html#elem
|
| 56 |
+
if (atom.elem === 'X'){
|
| 57 |
+
return "#FF0000" // red
|
| 58 |
+
}
|
| 59 |
+
else if (atom.elem === 'He'){
|
| 60 |
+
return "#0000FF" // blue
|
| 61 |
+
}
|
| 62 |
+
else{
|
| 63 |
+
return $3Dmol.getColorFromStyle(atom, {colorscheme: "whiteCarbon"});
|
| 64 |
+
}
|
| 65 |
+
}
|
| 66 |
+
\n'''
|
| 67 |
+
|
| 68 |
+
if atoms is not None:
|
| 69 |
+
# Write structure to xyz file.
|
| 70 |
+
xyz = io.StringIO()
|
| 71 |
+
ase_io.write(xyz, atoms, format='xyz')
|
| 72 |
+
|
| 73 |
+
view.addModel(xyz.getvalue(), 'xyz')
|
| 74 |
+
view.setStyle({'stick': {'radius': 0.1}, 'sphere': {'scale': 0.25}})
|
| 75 |
+
|
| 76 |
+
if dcm_charges is not None:
|
| 77 |
+
xyz_dcm = io.StringIO()
|
| 78 |
+
ase_io.write(xyz_dcm, dcm_charges, format='xyz')
|
| 79 |
+
view.addModel(xyz_dcm.getvalue(), 'xyz')
|
| 80 |
+
view.setStyle({'model': -1}, {'stick': {'radius': 0.01}, 'sphere': {'scale': 0.05}})
|
| 81 |
+
view.startjs = view.startjs.replace('"customColorize"', 'customColorize')
|
| 82 |
+
view.zoomTo()
|
| 83 |
+
return view
|
| 84 |
+
|
| 85 |
+
# create a slider for n_dcm
|
| 86 |
+
n_dcm = st.slider("Number of DCM layers", min_value=1, max_value=4, value=1)
|
| 87 |
+
|
| 88 |
+
# create a button to run the DCM-Net
|
| 89 |
+
if st.button("Run DCM-Net"):
|
| 90 |
+
results = run_dcm(smiles, n_dcm)
|
| 91 |
+
print(results)
|
| 92 |
+
|
| 93 |
+
st.subheader("Model Output")
|
| 94 |
+
# ase object
|
| 95 |
+
ase_atoms = results["atoms"]
|
| 96 |
+
|
| 97 |
+
# Display the 3D model
|
| 98 |
+
|
| 99 |
+
col1, col2 = st.columns(2)
|
| 100 |
+
with col1:
|
| 101 |
+
showmol(render_3d(ase_atoms), height=300, width=1000)
|
| 102 |
+
with col2:
|
| 103 |
+
st.image(results["smiles_image"])
|
| 104 |
|
| 105 |
+
st.subheader("DCM-1")
|
| 106 |
+
dcmcolors1 = color_charges(results["mono_dc1"], 1, ase_atoms)
|
| 107 |
+
# two columns showing combined and just the dcmol
|
| 108 |
+
col1, col2 = st.columns(2)
|
| 109 |
+
with col1:
|
| 110 |
+
showmol(render_3d(ase_atoms, results["dcmol"]), height=300, width=1000)
|
| 111 |
+
with col2:
|
| 112 |
+
showmol(render_3d(None, results["dcmol"], dcmcolors1), height=300, width=1000)
|
| 113 |
|
| 114 |
+
if n_dcm >= 2:
|
| 115 |
+
st.subheader("DCM-2")
|
| 116 |
+
dcmcolors2 = color_charges(results["mono_dc2"], 2, ase_atoms)
|
| 117 |
+
col1, col2 = st.columns(2)
|
| 118 |
+
with col1:
|
| 119 |
+
showmol(render_3d(ase_atoms, results["dcmol2"]), height=300, width=1000)
|
| 120 |
+
with col2:
|
| 121 |
+
showmol(render_3d(None, results["dcmol2"], dcmcolors2), height=300, width=1000)
|
| 122 |
|
| 123 |
+
if n_dcm >= 3:
|
| 124 |
+
st.subheader("DCM-3")
|
| 125 |
+
dcmcolors3 = color_charges(results["mono_dc3"], 3, ase_atoms)
|
| 126 |
+
col1, col2 = st.columns(2)
|
| 127 |
+
with col1:
|
| 128 |
+
showmol(render_3d(ase_atoms, results["dcmol3"]), height=300, width=1000)
|
| 129 |
+
with col2:
|
| 130 |
+
showmol(render_3d(None, results["dcmol3"], dcmcolors3), height=300, width=1000)
|
| 131 |
|
| 132 |
+
if n_dcm >= 4:
|
| 133 |
+
st.subheader("DCM-4")
|
| 134 |
+
dcmcolors4 = color_charges(results["mono_dc4"], 4, ase_atoms)
|
| 135 |
+
col1, col2 = st.columns(2)
|
| 136 |
+
with col1:
|
| 137 |
+
showmol(render_3d(ase_atoms, results["dcmol4"]), height=300, width=1000)
|
| 138 |
+
with col2:
|
| 139 |
+
showmol(render_3d(None, results["dcmol4"], dcmcolors4), height=300, width=1000)
|
| 140 |
|
dcm_app.py
CHANGED
|
@@ -48,7 +48,23 @@ def create_models():
|
|
| 48 |
cutoff=cutoff,
|
| 49 |
n_dcm=2,
|
| 50 |
)
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
def get_grid_points(coordinates):
|
|
@@ -83,7 +99,7 @@ def restore_arrays(obj):
|
|
| 83 |
if any(isinstance(value, dict) for value in restored):
|
| 84 |
return restored
|
| 85 |
try:
|
| 86 |
-
return
|
| 87 |
except Exception:
|
| 88 |
return restored
|
| 89 |
return obj
|
|
@@ -96,9 +112,11 @@ def load_json_dict(path):
|
|
| 96 |
|
| 97 |
|
| 98 |
def load_weights():
|
| 99 |
-
dcm1_weights = load_json_dict("wbs/best_0.
|
| 100 |
-
dcm2_weights = load_json_dict("wbs/dcm2-best_1000.
|
| 101 |
-
|
|
|
|
|
|
|
| 102 |
|
| 103 |
|
| 104 |
def prepare_inputs(smiles):
|
|
@@ -211,30 +229,41 @@ def normalize_batch(batch):
|
|
| 211 |
|
| 212 |
return batch
|
| 213 |
|
| 214 |
-
def run_dcm(smiles="C1NCCCC1"):
|
| 215 |
-
dcm1, dcm2 = create_models()
|
| 216 |
-
dcm1_weights, dcm2_weights = load_weights()
|
| 217 |
data_batch, smiles_image = prepare_inputs(smiles)
|
| 218 |
|
| 219 |
batch_size = 1
|
| 220 |
psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)
|
| 221 |
batch = normalize_batch(psi4_test_batches[0])
|
| 222 |
|
| 223 |
-
|
| 224 |
-
|
|
|
|
| 225 |
|
| 226 |
-
|
| 227 |
-
dcm2_results = do_eval(batch, dipo_dc2, mono_dc2, batch_size, n_dcm=2)
|
| 228 |
-
|
| 229 |
-
return {
|
| 230 |
-
"smiles_image": smiles_image,
|
| 231 |
"atoms": dcm1_results["atoms"],
|
| 232 |
"dcmol": dcm1_results["dcmol"],
|
| 233 |
-
"
|
|
|
|
| 234 |
}
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
-
|
| 238 |
-
# smiles = "C1NCCCC1"
|
| 239 |
-
# results = run_dcm(smiles)
|
| 240 |
-
# print(results)
|
|
|
|
| 48 |
cutoff=cutoff,
|
| 49 |
n_dcm=2,
|
| 50 |
)
|
| 51 |
+
dcm3 = MessagePassingModel(
|
| 52 |
+
features=features,
|
| 53 |
+
max_degree=max_degree,
|
| 54 |
+
num_iterations=num_iterations,
|
| 55 |
+
num_basis_functions=num_basis_functions,
|
| 56 |
+
cutoff=cutoff,
|
| 57 |
+
n_dcm=3,
|
| 58 |
+
)
|
| 59 |
+
dcm4 = MessagePassingModel(
|
| 60 |
+
features=features,
|
| 61 |
+
max_degree=max_degree,
|
| 62 |
+
num_iterations=num_iterations,
|
| 63 |
+
num_basis_functions=num_basis_functions,
|
| 64 |
+
cutoff=cutoff,
|
| 65 |
+
n_dcm=4,
|
| 66 |
+
)
|
| 67 |
+
return dcm1, dcm2, dcm3, dcm4
|
| 68 |
|
| 69 |
|
| 70 |
def get_grid_points(coordinates):
|
|
|
|
| 99 |
if any(isinstance(value, dict) for value in restored):
|
| 100 |
return restored
|
| 101 |
try:
|
| 102 |
+
return jnp.asarray(restored)
|
| 103 |
except Exception:
|
| 104 |
return restored
|
| 105 |
return obj
|
|
|
|
| 112 |
|
| 113 |
|
| 114 |
def load_weights():
|
| 115 |
+
dcm1_weights = load_json_dict("wbs/best_0.0_params_dict.json")
|
| 116 |
+
dcm2_weights = load_json_dict("wbs/dcm2-best_1000.0_params_dict.json")
|
| 117 |
+
dcm3_weights = load_json_dict("wbs/dcm3-best_1000.0_params_dict.json")
|
| 118 |
+
dcm4_weights = load_json_dict("wbs/dcm4-best_1000.0_params_dict.json")
|
| 119 |
+
return dcm1_weights, dcm2_weights, dcm3_weights, dcm4_weights
|
| 120 |
|
| 121 |
|
| 122 |
def prepare_inputs(smiles):
|
|
|
|
| 229 |
|
| 230 |
return batch
|
| 231 |
|
| 232 |
+
def run_dcm(smiles="C1NCCCC1", n_dcm=1):
|
| 233 |
+
dcm1, dcm2, dcm3, dcm4 = create_models()
|
| 234 |
+
dcm1_weights, dcm2_weights, dcm3_weights, dcm4_weights = load_weights()
|
| 235 |
data_batch, smiles_image = prepare_inputs(smiles)
|
| 236 |
|
| 237 |
batch_size = 1
|
| 238 |
psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)
|
| 239 |
batch = normalize_batch(psi4_test_batches[0])
|
| 240 |
|
| 241 |
+
if n_dcm >= 1:
|
| 242 |
+
mono_dc1, dipo_dc1 = apply_model(dcm1, dcm1_weights, batch, batch_size)
|
| 243 |
+
dcm1_results = do_eval(batch, dipo_dc1, mono_dc1, batch_size, n_dcm=1)
|
| 244 |
|
| 245 |
+
results = {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
"atoms": dcm1_results["atoms"],
|
| 247 |
"dcmol": dcm1_results["dcmol"],
|
| 248 |
+
"smiles_image": smiles_image,
|
| 249 |
+
"mono_dc1": mono_dc1,
|
| 250 |
}
|
| 251 |
|
| 252 |
+
if n_dcm >= 2:
|
| 253 |
+
mono_dc2, dipo_dc2 = apply_model(dcm2, dcm2_weights, batch, batch_size)
|
| 254 |
+
dcm2_results = do_eval(batch, dipo_dc2, mono_dc2, batch_size, n_dcm=2)
|
| 255 |
+
results["dcmol2"] = dcm2_results["dcmol"]
|
| 256 |
+
results["mono_dc2"] = mono_dc2
|
| 257 |
+
if n_dcm >= 3:
|
| 258 |
+
mono_dc3, dipo_dc3 = apply_model(dcm3, dcm3_weights, batch, batch_size)
|
| 259 |
+
dcm3_results = do_eval(batch, dipo_dc3, mono_dc3, batch_size, n_dcm=3)
|
| 260 |
+
results["dcmol3"] = dcm3_results["dcmol"]
|
| 261 |
+
results["mono_dc3"] = mono_dc3
|
| 262 |
+
if n_dcm >= 4:
|
| 263 |
+
mono_dc4, dipo_dc4 = apply_model(dcm4, dcm4_weights, batch, batch_size)
|
| 264 |
+
dcm4_results = do_eval(batch, dipo_dc4, mono_dc4, batch_size, n_dcm=4)
|
| 265 |
+
results["dcmol4"] = dcm4_results["dcmol"]
|
| 266 |
+
results["mono_dc4"] = mono_dc4
|
| 267 |
+
|
| 268 |
|
| 269 |
+
return results
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -5,3 +5,7 @@ scipy
|
|
| 5 |
rdkit
|
| 6 |
plotnine
|
| 7 |
dcmnet @ git+https://github.com/EricBoittier/dcmnet@main
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
rdkit
|
| 6 |
plotnine
|
| 7 |
dcmnet @ git+https://github.com/EricBoittier/dcmnet@main
|
| 8 |
+
streamlit_ketcher
|
| 9 |
+
stmol
|
| 10 |
+
py3Dmol==2.0.0.post2
|
| 11 |
+
ipython_genutils
|