EricBoi commited on
Commit
d8bf97a
·
1 Parent(s): ef1fe5f

Enhance DCM-Net functionality in app.py by integrating support for multiple DCM layers, updating the rendering of 3D models, and improving the user interface with sliders and buttons. Extend dcm_app.py to create and load models for DCM-3 and DCM-4, and adjust weight loading to accommodate additional models. Update requirements.txt to include new dependencies for molecule visualization.

Browse files
Files changed (4) hide show
  1. __pycache__/dcm_app.cpython-310.pyc +0 -0
  2. app.py +126 -12
  3. dcm_app.py +49 -20
  4. requirements.txt +4 -0
__pycache__/dcm_app.cpython-310.pyc CHANGED
Binary files a/__pycache__/dcm_app.cpython-310.pyc and b/__pycache__/dcm_app.cpython-310.pyc differ
 
app.py CHANGED
@@ -1,5 +1,10 @@
1
  import streamlit as st
2
-
 
 
 
 
 
3
  try:
4
  from StringIO import StringIO
5
  except ImportError:
@@ -7,20 +12,129 @@ except ImportError:
7
 
8
  import streamlit.components.v1 as components
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  from dcm_app import run_dcm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- smiles = "C1NCCCC1"
13
- results = run_dcm(smiles)
14
- print(results)
15
- st.image(results["smiles_image"])
16
- st.write("Click M to see the distributed charges")
 
 
 
17
 
18
- output = StringIO()
19
- (results["atoms"] + results["dcmol"]).write(output, format="html")
20
- components.html(output.getvalue(), width=1000, height=1000)
 
 
 
 
 
21
 
22
- output = StringIO()
23
- (results["atoms"] + results["dcmol2"]).write(output, format="html")
24
- components.html(output.getvalue(), width=1000, height=1000)
 
 
 
 
 
25
 
 
 
 
 
 
 
 
 
26
 
 
1
  import streamlit as st
2
+ from streamlit_ketcher import st_ketcher
3
+ from stmol import showmol
4
+ import py3Dmol
5
+ import io
6
+ import ase.io as ase_io
7
+ import numpy as np
8
  try:
9
  from StringIO import StringIO
10
  except ImportError:
 
12
 
13
  import streamlit.components.v1 as components
14
 
15
+
16
+ import streamlit as st
17
+ from streamlit_ketcher import st_ketcher
18
+
19
+ st.title("DCM-Net")
20
+ smiles = "OC(=O)[C@@H]1CCCN1"
21
+ # Create a molecule editor
22
+ smiles = st_ketcher(smiles)
23
+
24
+ # Output the SMILES string of the drawn molecule
25
+ st.write(f"SMILES: {smiles}")
26
+
27
+
28
+ # color charges by magnitude
29
+ def color_charges(mono, nDCM, atoms):
30
+ from matplotlib import cm
31
+ from matplotlib.colors import Normalize
32
+
33
+ norm = Normalize(vmin=-1, vmax=1)
34
+ cmap = cm.get_cmap("bwr")
35
+ mappable = cm.ScalarMappable(norm=norm, cmap=cmap)
36
+ elem = atoms.numbers
37
+ pccolors = mappable.to_rgba(mono.flatten()[: len(atoms.numbers)])
38
+ from ase.data.colors import jmol_colors
39
+
40
+ atomcolors = [jmol_colors[_] for _ in elem]
41
+ atomcolors_ = []
42
+ for _ in atomcolors:
43
+ atomcolors_.append(np.append(_, 0.015))
44
+ dcmcolors = mappable.to_rgba(mono.flatten()[: len(elem) * nDCM])
45
+ return dcmcolors
46
+
47
+
48
  from dcm_app import run_dcm
49
+ # Create a 3D view
50
+ def render_3d(atoms, dcm_charges=None, dcmcolors=None):
51
+ view = py3Dmol.view()
52
+
53
+ view.startjs += '''\n
54
+ let customColorize = function(atom){
55
+ // attribute elem is from https://3dmol.csb.pitt.edu/doc/AtomSpec.html#elem
56
+ if (atom.elem === 'X'){
57
+ return "#FF0000" // red
58
+ }
59
+ else if (atom.elem === 'He'){
60
+ return "#0000FF" // blue
61
+ }
62
+ else{
63
+ return $3Dmol.getColorFromStyle(atom, {colorscheme: "whiteCarbon"});
64
+ }
65
+ }
66
+ \n'''
67
+
68
+ if atoms is not None:
69
+ # Write structure to xyz file.
70
+ xyz = io.StringIO()
71
+ ase_io.write(xyz, atoms, format='xyz')
72
+
73
+ view.addModel(xyz.getvalue(), 'xyz')
74
+ view.setStyle({'stick': {'radius': 0.1}, 'sphere': {'scale': 0.25}})
75
+
76
+ if dcm_charges is not None:
77
+ xyz_dcm = io.StringIO()
78
+ ase_io.write(xyz_dcm, dcm_charges, format='xyz')
79
+ view.addModel(xyz_dcm.getvalue(), 'xyz')
80
+ view.setStyle({'model': -1}, {'stick': {'radius': 0.01}, 'sphere': {'scale': 0.05}})
81
+ view.startjs = view.startjs.replace('"customColorize"', 'customColorize')
82
+ view.zoomTo()
83
+ return view
84
+
85
+ # create a slider for n_dcm
86
+ n_dcm = st.slider("Number of DCM layers", min_value=1, max_value=4, value=1)
87
+
88
+ # create a button to run the DCM-Net
89
+ if st.button("Run DCM-Net"):
90
+ results = run_dcm(smiles, n_dcm)
91
+ print(results)
92
+
93
+ st.subheader("Model Output")
94
+ # ase object
95
+ ase_atoms = results["atoms"]
96
+
97
+ # Display the 3D model
98
+
99
+ col1, col2 = st.columns(2)
100
+ with col1:
101
+ showmol(render_3d(ase_atoms), height=300, width=1000)
102
+ with col2:
103
+ st.image(results["smiles_image"])
104
 
105
+ st.subheader("DCM-1")
106
+ dcmcolors1 = color_charges(results["mono_dc1"], 1, ase_atoms)
107
+ # two columns showing combined and just the dcmol
108
+ col1, col2 = st.columns(2)
109
+ with col1:
110
+ showmol(render_3d(ase_atoms, results["dcmol"]), height=300, width=1000)
111
+ with col2:
112
+ showmol(render_3d(None, results["dcmol"], dcmcolors1), height=300, width=1000)
113
 
114
+ if n_dcm >= 2:
115
+ st.subheader("DCM-2")
116
+ dcmcolors2 = color_charges(results["mono_dc2"], 2, ase_atoms)
117
+ col1, col2 = st.columns(2)
118
+ with col1:
119
+ showmol(render_3d(ase_atoms, results["dcmol2"]), height=300, width=1000)
120
+ with col2:
121
+ showmol(render_3d(None, results["dcmol2"], dcmcolors2), height=300, width=1000)
122
 
123
+ if n_dcm >= 3:
124
+ st.subheader("DCM-3")
125
+ dcmcolors3 = color_charges(results["mono_dc3"], 3, ase_atoms)
126
+ col1, col2 = st.columns(2)
127
+ with col1:
128
+ showmol(render_3d(ase_atoms, results["dcmol3"]), height=300, width=1000)
129
+ with col2:
130
+ showmol(render_3d(None, results["dcmol3"], dcmcolors3), height=300, width=1000)
131
 
132
+ if n_dcm >= 4:
133
+ st.subheader("DCM-4")
134
+ dcmcolors4 = color_charges(results["mono_dc4"], 4, ase_atoms)
135
+ col1, col2 = st.columns(2)
136
+ with col1:
137
+ showmol(render_3d(ase_atoms, results["dcmol4"]), height=300, width=1000)
138
+ with col2:
139
+ showmol(render_3d(None, results["dcmol4"], dcmcolors4), height=300, width=1000)
140
 
dcm_app.py CHANGED
@@ -48,7 +48,23 @@ def create_models():
48
  cutoff=cutoff,
49
  n_dcm=2,
50
  )
51
- return dcm1, dcm2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
 
54
  def get_grid_points(coordinates):
@@ -83,7 +99,7 @@ def restore_arrays(obj):
83
  if any(isinstance(value, dict) for value in restored):
84
  return restored
85
  try:
86
- return np.array(restored)
87
  except Exception:
88
  return restored
89
  return obj
@@ -96,9 +112,11 @@ def load_json_dict(path):
96
 
97
 
98
  def load_weights():
99
- dcm1_weights = load_json_dict("wbs/best_0.0_params.json")
100
- dcm2_weights = load_json_dict("wbs/dcm2-best_1000.0_params.json")
101
- return dcm1_weights, dcm2_weights
 
 
102
 
103
 
104
  def prepare_inputs(smiles):
@@ -211,30 +229,41 @@ def normalize_batch(batch):
211
 
212
  return batch
213
 
214
- def run_dcm(smiles="C1NCCCC1"):
215
- dcm1, dcm2 = create_models()
216
- dcm1_weights, dcm2_weights = load_weights()
217
  data_batch, smiles_image = prepare_inputs(smiles)
218
 
219
  batch_size = 1
220
  psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)
221
  batch = normalize_batch(psi4_test_batches[0])
222
 
223
- mono_dc1, dipo_dc1 = apply_model(dcm1, dcm1_weights, batch, batch_size)
224
- mono_dc2, dipo_dc2 = apply_model(dcm2, dcm2_weights, batch, batch_size)
 
225
 
226
- dcm1_results = do_eval(batch, dipo_dc1, mono_dc1, batch_size, n_dcm=1)
227
- dcm2_results = do_eval(batch, dipo_dc2, mono_dc2, batch_size, n_dcm=2)
228
-
229
- return {
230
- "smiles_image": smiles_image,
231
  "atoms": dcm1_results["atoms"],
232
  "dcmol": dcm1_results["dcmol"],
233
- "dcmol2": dcm2_results["dcmol"],
 
234
  }
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
- # if __name__ == "__main__":
238
- # smiles = "C1NCCCC1"
239
- # results = run_dcm(smiles)
240
- # print(results)
 
48
  cutoff=cutoff,
49
  n_dcm=2,
50
  )
51
+ dcm3 = MessagePassingModel(
52
+ features=features,
53
+ max_degree=max_degree,
54
+ num_iterations=num_iterations,
55
+ num_basis_functions=num_basis_functions,
56
+ cutoff=cutoff,
57
+ n_dcm=3,
58
+ )
59
+ dcm4 = MessagePassingModel(
60
+ features=features,
61
+ max_degree=max_degree,
62
+ num_iterations=num_iterations,
63
+ num_basis_functions=num_basis_functions,
64
+ cutoff=cutoff,
65
+ n_dcm=4,
66
+ )
67
+ return dcm1, dcm2, dcm3, dcm4
68
 
69
 
70
  def get_grid_points(coordinates):
 
99
  if any(isinstance(value, dict) for value in restored):
100
  return restored
101
  try:
102
+ return jnp.asarray(restored)
103
  except Exception:
104
  return restored
105
  return obj
 
112
 
113
 
114
  def load_weights():
115
+ dcm1_weights = load_json_dict("wbs/best_0.0_params_dict.json")
116
+ dcm2_weights = load_json_dict("wbs/dcm2-best_1000.0_params_dict.json")
117
+ dcm3_weights = load_json_dict("wbs/dcm3-best_1000.0_params_dict.json")
118
+ dcm4_weights = load_json_dict("wbs/dcm4-best_1000.0_params_dict.json")
119
+ return dcm1_weights, dcm2_weights, dcm3_weights, dcm4_weights
120
 
121
 
122
  def prepare_inputs(smiles):
 
229
 
230
  return batch
231
 
232
+ def run_dcm(smiles="C1NCCCC1", n_dcm=1):
233
+ dcm1, dcm2, dcm3, dcm4 = create_models()
234
+ dcm1_weights, dcm2_weights, dcm3_weights, dcm4_weights = load_weights()
235
  data_batch, smiles_image = prepare_inputs(smiles)
236
 
237
  batch_size = 1
238
  psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)
239
  batch = normalize_batch(psi4_test_batches[0])
240
 
241
+ if n_dcm >= 1:
242
+ mono_dc1, dipo_dc1 = apply_model(dcm1, dcm1_weights, batch, batch_size)
243
+ dcm1_results = do_eval(batch, dipo_dc1, mono_dc1, batch_size, n_dcm=1)
244
 
245
+ results = {
 
 
 
 
246
  "atoms": dcm1_results["atoms"],
247
  "dcmol": dcm1_results["dcmol"],
248
+ "smiles_image": smiles_image,
249
+ "mono_dc1": mono_dc1,
250
  }
251
 
252
+ if n_dcm >= 2:
253
+ mono_dc2, dipo_dc2 = apply_model(dcm2, dcm2_weights, batch, batch_size)
254
+ dcm2_results = do_eval(batch, dipo_dc2, mono_dc2, batch_size, n_dcm=2)
255
+ results["dcmol2"] = dcm2_results["dcmol"]
256
+ results["mono_dc2"] = mono_dc2
257
+ if n_dcm >= 3:
258
+ mono_dc3, dipo_dc3 = apply_model(dcm3, dcm3_weights, batch, batch_size)
259
+ dcm3_results = do_eval(batch, dipo_dc3, mono_dc3, batch_size, n_dcm=3)
260
+ results["dcmol3"] = dcm3_results["dcmol"]
261
+ results["mono_dc3"] = mono_dc3
262
+ if n_dcm >= 4:
263
+ mono_dc4, dipo_dc4 = apply_model(dcm4, dcm4_weights, batch, batch_size)
264
+ dcm4_results = do_eval(batch, dipo_dc4, mono_dc4, batch_size, n_dcm=4)
265
+ results["dcmol4"] = dcm4_results["dcmol"]
266
+ results["mono_dc4"] = mono_dc4
267
+
268
 
269
+ return results
 
 
 
requirements.txt CHANGED
@@ -5,3 +5,7 @@ scipy
5
  rdkit
6
  plotnine
7
  dcmnet @ git+https://github.com/EricBoittier/dcmnet@main
 
 
 
 
 
5
  rdkit
6
  plotnine
7
  dcmnet @ git+https://github.com/EricBoittier/dcmnet@main
8
+ streamlit_ketcher
9
+ stmol
10
+ py3Dmol==2.0.0.post2
11
+ ipython_genutils