EricBoi commited on
Commit
100bebf
·
1 Parent(s): 88cd151

Enhance DCM evaluation by adding normalization for batch inputs and updating the evaluation function to accept a variable number of DCMs; include print statement for results in app.py.

Browse files
__pycache__/dcm_app.cpython-310.pyc CHANGED
Binary files a/__pycache__/dcm_app.cpython-310.pyc and b/__pycache__/dcm_app.cpython-310.pyc differ
 
__pycache__/dcm_app.cpython-313.pyc CHANGED
Binary files a/__pycache__/dcm_app.cpython-313.pyc and b/__pycache__/dcm_app.cpython-313.pyc differ
 
app.py CHANGED
@@ -11,7 +11,7 @@ from dcm_app import run_dcm
11
 
12
  smiles = "C1NCCCC1"
13
  results = run_dcm(smiles)
14
-
15
  st.image(results["smiles_image"])
16
  st.write("Click M to see the distributed charges")
17
 
 
11
 
12
  smiles = "C1NCCCC1"
13
  results = run_dcm(smiles)
14
+ print(results)
15
  st.image(results["smiles_image"])
16
  st.write("Click M to see the distributed charges")
17
 
dcm_app.py CHANGED
@@ -138,19 +138,28 @@ def prepare_inputs(smiles):
138
 
139
  return data_batch, smiles_image
140
 
141
- def do_eval(batch, dipo_dc1, mono_dc1, batch_size):
142
  esp_errors, mono_pred, _, _ = evaluate_dc(
143
  batch,
144
  dipo_dc1,
145
  mono_dc1,
146
  batch_size,
147
- 1,
148
  plot=False,
149
 
150
  )
151
 
152
- atoms, dcmol, grid, esp, esp_dc_pred, idx_cut = create_plots2(
153
- mono_dc1, dipo_dc1, batch, batch_size, 1
 
 
 
 
 
 
 
 
 
154
  )
155
  outDict = {
156
  "mono": mono_dc1,
@@ -158,15 +167,30 @@ def do_eval(batch, dipo_dc1, mono_dc1, batch_size):
158
  "esp_errors": esp_errors,
159
  "atoms": atoms,
160
  "dcmol": dcmol,
161
- "grid": grid,
162
- "esp": esp,
163
- "esp_dc_pred": esp_dc_pred,
164
  "esp_mono_pred": mono_pred,
165
- "idx_cut": idx_cut,
166
  }
167
 
168
  return outDict
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  def run_dcm(smiles="C1NCCCC1"):
171
  dcm1, dcm2 = create_models()
172
  dcm1_weights, dcm2_weights = load_weights()
@@ -174,13 +198,13 @@ def run_dcm(smiles="C1NCCCC1"):
174
 
175
  batch_size = 1
176
  psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)
177
- batch = psi4_test_batches[0]
178
 
179
  mono_dc1, dipo_dc1 = apply_model(dcm1, dcm1_weights, batch, batch_size)
180
  mono_dc2, dipo_dc2 = apply_model(dcm2, dcm2_weights, batch, batch_size)
181
 
182
- dcm1_results = do_eval(batch, dipo_dc1, mono_dc1, batch_size)
183
- dcm2_results = do_eval(batch, dipo_dc2, mono_dc2, batch_size)
184
 
185
  return {
186
  "smiles_image": smiles_image,
@@ -190,7 +214,7 @@ def run_dcm(smiles="C1NCCCC1"):
190
  }
191
 
192
 
193
- if __name__ == "__main__":
194
- smiles = "C1NCCCC1"
195
- results = run_dcm(smiles)
196
- print(results)
 
138
 
139
  return data_batch, smiles_image
140
 
141
+ def do_eval(batch, dipo_dc1, mono_dc1, batch_size, n_dcm):
142
  esp_errors, mono_pred, _, _ = evaluate_dc(
143
  batch,
144
  dipo_dc1,
145
  mono_dc1,
146
  batch_size,
147
+ n_dcm,
148
  plot=False,
149
 
150
  )
151
 
152
+ n_atoms = int(batch.get("N", jnp.array([jnp.count_nonzero(batch["Z"])]))[0])
153
+ n_dcm = mono_dc1.shape[-1]
154
+ atoms = ase.Atoms(
155
+ numbers=np.array(batch["Z"][:n_atoms]),
156
+ positions=np.array(batch["R"][:n_atoms]),
157
+ )
158
+ dcm_positions = np.array(dipo_dc1).reshape(-1, 3)[: n_atoms * n_dcm]
159
+ dcm_charges = np.array(mono_dc1).reshape(-1)[: n_atoms * n_dcm]
160
+ dcmol = ase.Atoms(
161
+ ["X" if _ > 0 else "He" for _ in dcm_charges],
162
+ dcm_positions,
163
  )
164
  outDict = {
165
  "mono": mono_dc1,
 
167
  "esp_errors": esp_errors,
168
  "atoms": atoms,
169
  "dcmol": dcmol,
170
+ "grid": None,
171
+ "esp": None,
172
+ "esp_dc_pred": None,
173
  "esp_mono_pred": mono_pred,
174
+ "idx_cut": None,
175
  }
176
 
177
  return outDict
178
 
179
+ def normalize_batch(batch):
180
+ vdw_surface = batch.get("vdw_surface")
181
+ if vdw_surface is not None and vdw_surface.ndim == 4 and vdw_surface.shape[1] == 1:
182
+ batch["vdw_surface"] = vdw_surface.squeeze(axis=1)
183
+
184
+ esp = batch.get("esp")
185
+ if esp is not None and esp.ndim == 3 and esp.shape[1] == 1:
186
+ batch["esp"] = esp.squeeze(axis=1)
187
+
188
+ esp_mask = batch.get("espMask")
189
+ if esp_mask is not None and esp_mask.ndim == 3 and esp_mask.shape[1] == 1:
190
+ batch["espMask"] = esp_mask.squeeze(axis=1)
191
+
192
+ return batch
193
+
194
  def run_dcm(smiles="C1NCCCC1"):
195
  dcm1, dcm2 = create_models()
196
  dcm1_weights, dcm2_weights = load_weights()
 
198
 
199
  batch_size = 1
200
  psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)
201
+ batch = normalize_batch(psi4_test_batches[0])
202
 
203
  mono_dc1, dipo_dc1 = apply_model(dcm1, dcm1_weights, batch, batch_size)
204
  mono_dc2, dipo_dc2 = apply_model(dcm2, dcm2_weights, batch, batch_size)
205
 
206
+ dcm1_results = do_eval(batch, dipo_dc1, mono_dc1, batch_size, n_dcm=1)
207
+ dcm2_results = do_eval(batch, dipo_dc2, mono_dc2, batch_size, n_dcm=2)
208
 
209
  return {
210
  "smiles_image": smiles_image,
 
214
  }
215
 
216
 
217
+ # if __name__ == "__main__":
218
+ # smiles = "C1NCCCC1"
219
+ # results = run_dcm(smiles)
220
+ # print(results)