Enhance DCM evaluation by adding normalization for batch inputs and updating the evaluation function to accept a variable number of DCMs; include print statement for results in app.py.
Browse files
- __pycache__/dcm_app.cpython-310.pyc +0 -0
- __pycache__/dcm_app.cpython-313.pyc +0 -0
- app.py +1 -1
- dcm_app.py +39 -15
__pycache__/dcm_app.cpython-310.pyc
CHANGED
|
Binary files a/__pycache__/dcm_app.cpython-310.pyc and b/__pycache__/dcm_app.cpython-310.pyc differ
|
|
|
__pycache__/dcm_app.cpython-313.pyc
CHANGED
|
Binary files a/__pycache__/dcm_app.cpython-313.pyc and b/__pycache__/dcm_app.cpython-313.pyc differ
|
|
|
app.py
CHANGED
|
@@ -11,7 +11,7 @@ from dcm_app import run_dcm
|
|
| 11 |
|
| 12 |
smiles = "C1NCCCC1"
|
| 13 |
results = run_dcm(smiles)
|
| 14 |
-
|
| 15 |
st.image(results["smiles_image"])
|
| 16 |
st.write("Click M to see the distributed charges")
|
| 17 |
|
|
|
|
| 11 |
|
| 12 |
smiles = "C1NCCCC1"
|
| 13 |
results = run_dcm(smiles)
|
| 14 |
+
print(results)
|
| 15 |
st.image(results["smiles_image"])
|
| 16 |
st.write("Click M to see the distributed charges")
|
| 17 |
|
dcm_app.py
CHANGED
|
@@ -138,19 +138,28 @@ def prepare_inputs(smiles):
|
|
| 138 |
|
| 139 |
return data_batch, smiles_image
|
| 140 |
|
| 141 |
-
def do_eval(batch, dipo_dc1, mono_dc1, batch_size):
|
| 142 |
esp_errors, mono_pred, _, _ = evaluate_dc(
|
| 143 |
batch,
|
| 144 |
dipo_dc1,
|
| 145 |
mono_dc1,
|
| 146 |
batch_size,
|
| 147 |
-
|
| 148 |
plot=False,
|
| 149 |
|
| 150 |
)
|
| 151 |
|
| 152 |
-
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
)
|
| 155 |
outDict = {
|
| 156 |
"mono": mono_dc1,
|
|
@@ -158,15 +167,30 @@ def do_eval(batch, dipo_dc1, mono_dc1, batch_size):
|
|
| 158 |
"esp_errors": esp_errors,
|
| 159 |
"atoms": atoms,
|
| 160 |
"dcmol": dcmol,
|
| 161 |
-
"grid":
|
| 162 |
-
"esp":
|
| 163 |
-
"esp_dc_pred":
|
| 164 |
"esp_mono_pred": mono_pred,
|
| 165 |
-
"idx_cut":
|
| 166 |
}
|
| 167 |
|
| 168 |
return outDict
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
def run_dcm(smiles="C1NCCCC1"):
|
| 171 |
dcm1, dcm2 = create_models()
|
| 172 |
dcm1_weights, dcm2_weights = load_weights()
|
|
@@ -174,13 +198,13 @@ def run_dcm(smiles="C1NCCCC1"):
|
|
| 174 |
|
| 175 |
batch_size = 1
|
| 176 |
psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)
|
| 177 |
-
batch = psi4_test_batches[0]
|
| 178 |
|
| 179 |
mono_dc1, dipo_dc1 = apply_model(dcm1, dcm1_weights, batch, batch_size)
|
| 180 |
mono_dc2, dipo_dc2 = apply_model(dcm2, dcm2_weights, batch, batch_size)
|
| 181 |
|
| 182 |
-
dcm1_results = do_eval(batch, dipo_dc1, mono_dc1, batch_size)
|
| 183 |
-
dcm2_results = do_eval(batch, dipo_dc2, mono_dc2, batch_size)
|
| 184 |
|
| 185 |
return {
|
| 186 |
"smiles_image": smiles_image,
|
|
@@ -190,7 +214,7 @@ def run_dcm(smiles="C1NCCCC1"):
|
|
| 190 |
}
|
| 191 |
|
| 192 |
|
| 193 |
-
if __name__ == "__main__":
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
|
|
|
| 138 |
|
| 139 |
return data_batch, smiles_image
|
| 140 |
|
| 141 |
+
def do_eval(batch, dipo_dc1, mono_dc1, batch_size, n_dcm):
|
| 142 |
esp_errors, mono_pred, _, _ = evaluate_dc(
|
| 143 |
batch,
|
| 144 |
dipo_dc1,
|
| 145 |
mono_dc1,
|
| 146 |
batch_size,
|
| 147 |
+
n_dcm,
|
| 148 |
plot=False,
|
| 149 |
|
| 150 |
)
|
| 151 |
|
| 152 |
+
n_atoms = int(batch.get("N", jnp.array([jnp.count_nonzero(batch["Z"])]))[0])
|
| 153 |
+
n_dcm = mono_dc1.shape[-1]
|
| 154 |
+
atoms = ase.Atoms(
|
| 155 |
+
numbers=np.array(batch["Z"][:n_atoms]),
|
| 156 |
+
positions=np.array(batch["R"][:n_atoms]),
|
| 157 |
+
)
|
| 158 |
+
dcm_positions = np.array(dipo_dc1).reshape(-1, 3)[: n_atoms * n_dcm]
|
| 159 |
+
dcm_charges = np.array(mono_dc1).reshape(-1)[: n_atoms * n_dcm]
|
| 160 |
+
dcmol = ase.Atoms(
|
| 161 |
+
["X" if _ > 0 else "He" for _ in dcm_charges],
|
| 162 |
+
dcm_positions,
|
| 163 |
)
|
| 164 |
outDict = {
|
| 165 |
"mono": mono_dc1,
|
|
|
|
| 167 |
"esp_errors": esp_errors,
|
| 168 |
"atoms": atoms,
|
| 169 |
"dcmol": dcmol,
|
| 170 |
+
"grid": None,
|
| 171 |
+
"esp": None,
|
| 172 |
+
"esp_dc_pred": None,
|
| 173 |
"esp_mono_pred": mono_pred,
|
| 174 |
+
"idx_cut": None,
|
| 175 |
}
|
| 176 |
|
| 177 |
return outDict
|
| 178 |
|
| 179 |
+
def normalize_batch(batch):
|
| 180 |
+
vdw_surface = batch.get("vdw_surface")
|
| 181 |
+
if vdw_surface is not None and vdw_surface.ndim == 4 and vdw_surface.shape[1] == 1:
|
| 182 |
+
batch["vdw_surface"] = vdw_surface.squeeze(axis=1)
|
| 183 |
+
|
| 184 |
+
esp = batch.get("esp")
|
| 185 |
+
if esp is not None and esp.ndim == 3 and esp.shape[1] == 1:
|
| 186 |
+
batch["esp"] = esp.squeeze(axis=1)
|
| 187 |
+
|
| 188 |
+
esp_mask = batch.get("espMask")
|
| 189 |
+
if esp_mask is not None and esp_mask.ndim == 3 and esp_mask.shape[1] == 1:
|
| 190 |
+
batch["espMask"] = esp_mask.squeeze(axis=1)
|
| 191 |
+
|
| 192 |
+
return batch
|
| 193 |
+
|
| 194 |
def run_dcm(smiles="C1NCCCC1"):
|
| 195 |
dcm1, dcm2 = create_models()
|
| 196 |
dcm1_weights, dcm2_weights = load_weights()
|
|
|
|
| 198 |
|
| 199 |
batch_size = 1
|
| 200 |
psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)
|
| 201 |
+
batch = normalize_batch(psi4_test_batches[0])
|
| 202 |
|
| 203 |
mono_dc1, dipo_dc1 = apply_model(dcm1, dcm1_weights, batch, batch_size)
|
| 204 |
mono_dc2, dipo_dc2 = apply_model(dcm2, dcm2_weights, batch, batch_size)
|
| 205 |
|
| 206 |
+
dcm1_results = do_eval(batch, dipo_dc1, mono_dc1, batch_size, n_dcm=1)
|
| 207 |
+
dcm2_results = do_eval(batch, dipo_dc2, mono_dc2, batch_size, n_dcm=2)
|
| 208 |
|
| 209 |
return {
|
| 210 |
"smiles_image": smiles_image,
|
|
|
|
| 214 |
}
|
| 215 |
|
| 216 |
|
| 217 |
+
# if __name__ == "__main__":
|
| 218 |
+
# smiles = "C1NCCCC1"
|
| 219 |
+
# results = run_dcm(smiles)
|
| 220 |
+
# print(results)
|