EricBoi commited on
Commit
88cd151
·
1 Parent(s): 77fe6f0

Refactor app.py to integrate DCM functionality and streamline molecule visualization; update requirements.txt to include plotnine and correct dcmnet repository URL.

Browse files
__pycache__/dcm_app.cpython-310.pyc ADDED
Binary file (4.17 kB). View file
 
__pycache__/dcm_app.cpython-313.pyc ADDED
Binary file (6.55 kB). View file
 
app.py CHANGED
@@ -1,170 +1,26 @@
1
  import streamlit as st
2
 
3
- from ase.visualize import view
4
-
5
  try:
6
  from StringIO import StringIO
7
  except ImportError:
8
  from io import StringIO
9
 
10
  import streamlit.components.v1 as components
11
- from scipy.spatial.distance import cdist
12
-
13
- import ase
14
-
15
- import functools
16
- import e3x
17
- from flax import linen as nn
18
- import jax
19
- import jax.numpy as jnp
20
- import matplotlib.pyplot as plt
21
- import numpy as np
22
- import optax
23
-
24
-
25
- import pandas as pd
26
- from dcmnet.modules import MessagePassingModel
27
- from dcmnet.utils import clip_colors, apply_model
28
- from dcmnet.data import prepare_batches
29
- from dcmnet.plotting import plot_model
30
-
31
- RANDOM_NUMBER = 0
32
- filename = "test"
33
- data_key, train_key = jax.random.split(jax.random.PRNGKey(RANDOM_NUMBER), 2)
34
-
35
-
36
- # Model hyperparameters.
37
- features = 16
38
- max_degree = 2
39
- num_iterations = 2
40
- num_basis_functions = 8
41
- cutoff = 4.0
42
-
43
- # Create models
44
- DCM1 = MessagePassingModel(
45
- features=features,
46
- max_degree=max_degree,
47
- num_iterations=num_iterations,
48
- num_basis_functions=num_basis_functions,
49
- cutoff=cutoff,
50
- n_dcm=1,
51
- )
52
-
53
- # Create models
54
- DCM2 = MessagePassingModel(
55
- features=features,
56
- max_degree=max_degree,
57
- num_iterations=num_iterations,
58
- num_basis_functions=num_basis_functions,
59
- cutoff=cutoff,
60
- n_dcm=2,
61
- )
62
-
63
-
64
- from rdkit import Chem
65
- from rdkit.Chem import AllChem
66
- from rdkit.Chem import Draw
67
-
68
- def get_grid_points(coordinates):
69
- """
70
- create a uniform grid of points around the molecule,
71
- starting from minimum and maximum coordinates of the molecule (plus minus some padding)
72
- :param coordinates:
73
- :return:
74
- """
75
- bounds = np.array([np.min(coordinates, axis=0),
76
- np.max(coordinates, axis=0)])
77
- padding = 3.0
78
- bounds = bounds + np.array([-1, 1])[:, None] * padding
79
- grid_points = np.meshgrid(*[np.linspace(a, b, 15)
80
- for a, b in zip(bounds[0], bounds[1])])
81
-
82
- grid_points = np.stack(grid_points, axis=0)
83
- grid_points = np.reshape(grid_points.T, [-1, 3])
84
- # exclude points that are too close to the molecule
85
- grid_points = grid_points[
86
- #np.where(np.all(cdist(grid_points, coordinates) >= (2.0 - 1e-1), axis=-1))[0]]
87
- np.where(np.all(cdist(grid_points, coordinates) >= (2.5 - 1e-1), axis=-1))[0]]
88
-
89
- return grid_points
90
-
91
-
92
- dcm1_weights = pd.read_pickle("wbs/best_0.0_params.pkl")
93
- dcm2_weights = pd.read_pickle("wbs/dcm2-best_1000.0_params.pkl")
94
-
95
- smiles = 'C1NCCCC1'
96
-
97
- smiles_mol = Chem.MolFromSmiles(smiles)
98
- rdkit_mol = Chem.AddHs(smiles_mol)
99
- elements = [a.GetSymbol() for a in rdkit_mol.GetAtoms()]
100
- # Generate a conformation
101
- AllChem.EmbedMolecule(rdkit_mol)
102
- coordinates = rdkit_mol.GetConformer(0).GetPositions()
103
- surface = get_grid_points(coordinates)
104
-
105
- for i, atom in enumerate(smiles_mol.GetAtoms()):
106
- # For each atom, set the property "molAtomMapNumber" to a custom number, let's say, the index of the atom in the molecule
107
- atom.SetProp("atomNote", str(atom.GetIdx()))
108
-
109
- smiles_image = Draw.MolToImage(smiles_mol)
110
-
111
- # display molecule
112
- st.image(smiles_image)
113
-
114
-
115
- vdw_surface = surface
116
- max_N_atoms = 60
117
- max_grid_points = 3143
118
- max_grid_points - len(vdw_surface)
119
- try:
120
- Z = [np.array([int(_) for _ in elements])]
121
- except:
122
- Z = [np.array([ase.data.atomic_numbers[_.capitalize()] for _ in elements])]
123
- pad_Z = np.array([np.pad(Z[0], ((0,max_N_atoms - len(Z[0]))))])
124
- pad_coords = np.array([np.pad(coordinates, ((0, max_N_atoms - len(coordinates)), (0,0)))])
125
-
126
- pad_vdw_surface = []
127
- _ = np.pad(vdw_surface, ((0, max_grid_points - len(vdw_surface)), (0,0)), "constant", constant_values=(0, 10000))
128
- pad_vdw_surface.append(_)
129
- pad_vdw_surface = np.array(pad_vdw_surface)
130
-
131
-
132
- data_batch = dict(
133
- atomic_numbers=jnp.asarray(pad_Z),
134
- positions=jnp.asarray(pad_coords),
135
- mono=jnp.asarray(pad_Z),
136
- ngrid=jnp.array([len(vdw_surface)]),
137
- esp=jnp.asarray([np.zeros(max_grid_points)]),
138
- vdw_surface=jnp.asarray(pad_vdw_surface),
139
- )
140
-
141
- batch_size = 1
142
-
143
- psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)
144
-
145
- batchID = 0
146
- errors_train = []
147
- batch = psi4_test_batches[batchID]
148
-
149
- #mono, dipo = apply_model(DCM1, test_weights, batch, batch_size)
150
- dcm1results = plot_model(DCM1, dcm1_weights, batch, batch_size, 1, plot=False)
151
- dcm2results = plot_model(DCM2, dcm2_weights, batch, batch_size, 2, plot=False)
152
-
153
- atoms = dcm1results["atoms"]
154
- dcmol = dcm1results["dcmol"]
155
- dcmol2 = dcm2results["dcmol"]
156
 
 
157
 
 
 
158
 
 
159
  st.write("Click M to see the distributed charges")
 
160
  output = StringIO()
161
- (atoms+dcmol).write(output, format="html")
162
- data = output.getvalue()
163
- components.html(data, width=1000, height=1000)
164
 
165
  output = StringIO()
166
- (atoms+dcmol2).write(output, format="html")
167
- data = output.getvalue()
168
- components.html(data, width=1000, height=1000)
169
 
170
 
 
1
  import streamlit as st
2
 
 
 
3
  try:
4
  from StringIO import StringIO
5
  except ImportError:
6
  from io import StringIO
7
 
8
  import streamlit.components.v1 as components
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ from dcm_app import run_dcm
11
 
12
+ smiles = "C1NCCCC1"
13
+ results = run_dcm(smiles)
14
 
15
+ st.image(results["smiles_image"])
16
  st.write("Click M to see the distributed charges")
17
+
18
  output = StringIO()
19
+ (results["atoms"] + results["dcmol"]).write(output, format="html")
20
+ components.html(output.getvalue(), width=1000, height=1000)
 
21
 
22
  output = StringIO()
23
+ (results["atoms"] + results["dcmol2"]).write(output, format="html")
24
+ components.html(output.getvalue(), width=1000, height=1000)
 
25
 
26
 
dcm_app.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ase
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import numpy as np
5
+ import pandas as pd
6
+ from rdkit import Chem
7
+ from rdkit.Chem import AllChem
8
+ from rdkit.Chem import Draw
9
+ from scipy.spatial.distance import cdist
10
+
11
+ from dcmnet.data import prepare_batches
12
+ from dcmnet.modules import MessagePassingModel
13
+ from dcmnet.loss import (
14
+ esp_loss_eval,
15
+ esp_loss_pots,
16
+ esp_mono_loss_pots,
17
+ get_predictions,
18
+ )
19
+ from dcmnet.plotting import evaluate_dc, create_plots2
20
+ from dcmnet.multimodel import get_atoms_dcmol
21
+ from dcmnet.multipoles import plot_3d
22
+ from dcmnet.utils import apply_model, clip_colors, reshape_dipole
23
+ RANDOM_NUMBER = 0
24
+ data_key, _ = jax.random.split(jax.random.PRNGKey(RANDOM_NUMBER), 2)
25
+
26
+ # Model hyperparameters.
27
+ features = 16
28
+ max_degree = 2
29
+ num_iterations = 2
30
+ num_basis_functions = 8
31
+ cutoff = 4.0
32
+
33
+
34
+ def create_models():
35
+ dcm1 = MessagePassingModel(
36
+ features=features,
37
+ max_degree=max_degree,
38
+ num_iterations=num_iterations,
39
+ num_basis_functions=num_basis_functions,
40
+ cutoff=cutoff,
41
+ n_dcm=1,
42
+ )
43
+ dcm2 = MessagePassingModel(
44
+ features=features,
45
+ max_degree=max_degree,
46
+ num_iterations=num_iterations,
47
+ num_basis_functions=num_basis_functions,
48
+ cutoff=cutoff,
49
+ n_dcm=2,
50
+ )
51
+ return dcm1, dcm2
52
+
53
+
54
+ def get_grid_points(coordinates):
55
+ """
56
+ create a uniform grid of points around the molecule,
57
+ starting from minimum and maximum coordinates of the molecule (plus minus some padding)
58
+ :param coordinates:
59
+ :return:
60
+ """
61
+ bounds = np.array([np.min(coordinates, axis=0), np.max(coordinates, axis=0)])
62
+ padding = 3.0
63
+ bounds = bounds + np.array([-1, 1])[:, None] * padding
64
+ grid_points = np.meshgrid(
65
+ *[np.linspace(a, b, 15) for a, b in zip(bounds[0], bounds[1])]
66
+ )
67
+
68
+ grid_points = np.stack(grid_points, axis=0)
69
+ grid_points = np.reshape(grid_points.T, [-1, 3])
70
+ # exclude points that are too close to the molecule
71
+ grid_points = grid_points[
72
+ np.where(np.all(cdist(grid_points, coordinates) >= (2.5 - 1e-1), axis=-1))[0]
73
+ ]
74
+
75
+ return grid_points
76
+
77
+
78
+ def load_weights():
79
+ dcm1_weights = pd.read_pickle("wbs/best_0.0_params.pkl")
80
+ dcm2_weights = pd.read_pickle("wbs/dcm2-best_1000.0_params.pkl")
81
+ return dcm1_weights, dcm2_weights
82
+
83
+
84
+ def prepare_inputs(smiles):
85
+ smiles_mol = Chem.MolFromSmiles(smiles)
86
+ rdkit_mol = Chem.AddHs(smiles_mol)
87
+ elements = [a.GetSymbol() for a in rdkit_mol.GetAtoms()]
88
+ AllChem.EmbedMolecule(rdkit_mol)
89
+ coordinates = rdkit_mol.GetConformer(0).GetPositions()
90
+ surface = get_grid_points(coordinates)
91
+
92
+ for atom in smiles_mol.GetAtoms():
93
+ atom.SetProp("atomNote", str(atom.GetIdx()))
94
+
95
+ smiles_image = Draw.MolToImage(smiles_mol)
96
+
97
+ vdw_surface = surface
98
+ max_n_atoms = 60
99
+ max_grid_points = 3143
100
+ try:
101
+ z_values = [np.array([int(_) for _ in elements])]
102
+ except Exception:
103
+ z_values = [np.array([ase.data.atomic_numbers[_.capitalize()] for _ in elements])]
104
+
105
+ pad_z = np.array([np.pad(z_values[0], ((0, max_n_atoms - len(z_values[0]))))])
106
+ pad_coords = np.array(
107
+ [
108
+ np.pad(
109
+ coordinates, ((0, max_n_atoms - len(coordinates)), (0, 0))
110
+ )
111
+ ]
112
+ )
113
+
114
+ pad_vdw_surface = []
115
+ padded_surface = np.pad(
116
+ vdw_surface,
117
+ ((0, max_grid_points - len(vdw_surface)), (0, 0)),
118
+ "constant",
119
+ constant_values=(0, 10000),
120
+ )
121
+ pad_vdw_surface.append(padded_surface)
122
+ pad_vdw_surface = np.array(pad_vdw_surface)
123
+ n_atoms = np.sum(pad_z != 0)
124
+ data_batch = dict(
125
+ atomic_numbers=jnp.asarray(pad_z),
126
+ Z=jnp.asarray(pad_z),
127
+ positions=jnp.asarray(pad_coords),
128
+ R=jnp.asarray(pad_coords),
129
+ # N is the number of atoms
130
+ N=jnp.asarray([n_atoms]),
131
+ mono=jnp.asarray(pad_z),
132
+ ngrid=jnp.array([len(vdw_surface)]),
133
+ n_grid=jnp.array([len(vdw_surface)]),
134
+ esp=jnp.asarray([np.zeros(max_grid_points)]),
135
+ vdw_surface=jnp.asarray(pad_vdw_surface),
136
+ espMask=jnp.asarray([np.ones(max_grid_points)], dtype=jnp.bool_),
137
+ )
138
+
139
+ return data_batch, smiles_image
140
+
141
+ def do_eval(batch, dipo_dc1, mono_dc1, batch_size):
142
+ esp_errors, mono_pred, _, _ = evaluate_dc(
143
+ batch,
144
+ dipo_dc1,
145
+ mono_dc1,
146
+ batch_size,
147
+ 1,
148
+ plot=False,
149
+
150
+ )
151
+
152
+ atoms, dcmol, grid, esp, esp_dc_pred, idx_cut = create_plots2(
153
+ mono_dc1, dipo_dc1, batch, batch_size, 1
154
+ )
155
+ outDict = {
156
+ "mono": mono_dc1,
157
+ "dipo": dipo_dc1,
158
+ "esp_errors": esp_errors,
159
+ "atoms": atoms,
160
+ "dcmol": dcmol,
161
+ "grid": grid,
162
+ "esp": esp,
163
+ "esp_dc_pred": esp_dc_pred,
164
+ "esp_mono_pred": mono_pred,
165
+ "idx_cut": idx_cut,
166
+ }
167
+
168
+ return outDict
169
+
170
+ def run_dcm(smiles="C1NCCCC1"):
171
+ dcm1, dcm2 = create_models()
172
+ dcm1_weights, dcm2_weights = load_weights()
173
+ data_batch, smiles_image = prepare_inputs(smiles)
174
+
175
+ batch_size = 1
176
+ psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)
177
+ batch = psi4_test_batches[0]
178
+
179
+ mono_dc1, dipo_dc1 = apply_model(dcm1, dcm1_weights, batch, batch_size)
180
+ mono_dc2, dipo_dc2 = apply_model(dcm2, dcm2_weights, batch, batch_size)
181
+
182
+ dcm1_results = do_eval(batch, dipo_dc1, mono_dc1, batch_size)
183
+ dcm2_results = do_eval(batch, dipo_dc2, mono_dc2, batch_size)
184
+
185
+ return {
186
+ "smiles_image": smiles_image,
187
+ "atoms": dcm1_results["atoms"],
188
+ "dcmol": dcm1_results["dcmol"],
189
+ "dcmol2": dcm2_results["dcmol"],
190
+ }
191
+
192
+
193
+ if __name__ == "__main__":
194
+ smiles = "C1NCCCC1"
195
+ results = run_dcm(smiles)
196
+ print(results)
requirements.txt CHANGED
@@ -3,5 +3,5 @@ patchworklib
3
  ase
4
  scipy
5
  rdkit
6
- #tqdm
7
- dcmnet @ git+https://github.com/EricBoittier/jaxeq@main
 
3
  ase
4
  scipy
5
  rdkit
6
+ plotnine
7
+ dcmnet @ git+https://github.com/EricBoittier/dcmnet@main