Refactor weight loading in dcm_app.py to use JSON instead of pickle, adding a new function for restoring arrays from loaded data.
Browse files- dcm_app.py +21 -6
- scripts/convert_pkls_to_dicts.py +33 -4
dcm_app.py
CHANGED
|
@@ -2,7 +2,7 @@ import ase
|
|
| 2 |
import jax
|
| 3 |
import jax.numpy as jnp
|
| 4 |
import numpy as np
|
| 5 |
-
import
|
| 6 |
from rdkit import Chem
|
| 7 |
from rdkit.Chem import AllChem
|
| 8 |
from rdkit.Chem import Draw
|
|
@@ -75,14 +75,29 @@ def get_grid_points(coordinates):
|
|
| 75 |
return grid_points
|
| 76 |
|
| 77 |
|
| 78 |
-
def
|
| 79 |
-
|
| 80 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
|
| 83 |
def load_weights():
|
| 84 |
-
dcm1_weights =
|
| 85 |
-
dcm2_weights =
|
| 86 |
return dcm1_weights, dcm2_weights
|
| 87 |
|
| 88 |
|
|
|
|
| 2 |
import jax
|
| 3 |
import jax.numpy as jnp
|
| 4 |
import numpy as np
|
| 5 |
+
import json
|
| 6 |
from rdkit import Chem
|
| 7 |
from rdkit.Chem import AllChem
|
| 8 |
from rdkit.Chem import Draw
|
|
|
|
| 75 |
return grid_points
|
| 76 |
|
| 77 |
|
| 78 |
+
def restore_arrays(obj):
|
| 79 |
+
if isinstance(obj, dict):
|
| 80 |
+
return {key: restore_arrays(value) for key, value in obj.items()}
|
| 81 |
+
if isinstance(obj, list):
|
| 82 |
+
restored = [restore_arrays(value) for value in obj]
|
| 83 |
+
if any(isinstance(value, dict) for value in restored):
|
| 84 |
+
return restored
|
| 85 |
+
try:
|
| 86 |
+
return np.array(restored)
|
| 87 |
+
except Exception:
|
| 88 |
+
return restored
|
| 89 |
+
return obj
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def load_json_dict(path):
|
| 93 |
+
with open(path, "r", encoding="utf-8") as handle:
|
| 94 |
+
payload = json.load(handle)
|
| 95 |
+
return restore_arrays(payload)
|
| 96 |
|
| 97 |
|
| 98 |
def load_weights():
|
| 99 |
+
dcm1_weights = load_json_dict("wbs/best_0.0_params.json")
|
| 100 |
+
dcm2_weights = load_json_dict("wbs/dcm2-best_1000.0_params.json")
|
| 101 |
return dcm1_weights, dcm2_weights
|
| 102 |
|
| 103 |
|
scripts/convert_pkls_to_dicts.py
CHANGED
|
@@ -1,9 +1,20 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
import argparse
|
|
|
|
| 3 |
import pickle
|
| 4 |
from collections.abc import Mapping
|
| 5 |
from pathlib import Path
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
def as_dict(obj):
|
| 9 |
if isinstance(obj, dict):
|
|
@@ -15,18 +26,36 @@ def as_dict(obj):
|
|
| 15 |
return obj
|
| 16 |
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
def convert_file(source_path, output_path):
|
| 19 |
with source_path.open("rb") as handle:
|
| 20 |
payload = pickle.load(handle)
|
| 21 |
dict_payload = as_dict(payload)
|
| 22 |
-
|
| 23 |
-
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
def build_output_path(source_path, output_dir, suffix):
|
| 27 |
if output_dir is not None:
|
| 28 |
-
return output_dir / source_path.
|
| 29 |
-
return source_path.with_name(f"{source_path.stem}{suffix}
|
| 30 |
|
| 31 |
|
| 32 |
def main():
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
import argparse
|
| 3 |
+
import json
|
| 4 |
import pickle
|
| 5 |
from collections.abc import Mapping
|
| 6 |
from pathlib import Path
|
| 7 |
|
| 8 |
+
try:
|
| 9 |
+
import numpy as np
|
| 10 |
+
except Exception: # pragma: no cover - optional dependency for conversion
|
| 11 |
+
np = None
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
import jax.numpy as jnp
|
| 15 |
+
except Exception: # pragma: no cover - optional dependency for conversion
|
| 16 |
+
jnp = None
|
| 17 |
+
|
| 18 |
|
| 19 |
def as_dict(obj):
|
| 20 |
if isinstance(obj, dict):
|
|
|
|
| 26 |
return obj
|
| 27 |
|
| 28 |
|
| 29 |
+
def to_jsonable(obj):
|
| 30 |
+
if isinstance(obj, dict):
|
| 31 |
+
return {key: to_jsonable(value) for key, value in obj.items()}
|
| 32 |
+
if isinstance(obj, (list, tuple)):
|
| 33 |
+
return [to_jsonable(value) for value in obj]
|
| 34 |
+
if np is not None:
|
| 35 |
+
if isinstance(obj, np.ndarray):
|
| 36 |
+
return obj.tolist()
|
| 37 |
+
if isinstance(obj, (np.integer, np.floating)):
|
| 38 |
+
return obj.item()
|
| 39 |
+
if jnp is not None and isinstance(obj, jnp.ndarray):
|
| 40 |
+
return np.array(obj).tolist() if np is not None else obj.tolist()
|
| 41 |
+
if hasattr(obj, "tolist"):
|
| 42 |
+
return obj.tolist()
|
| 43 |
+
return obj
|
| 44 |
+
|
| 45 |
+
|
| 46 |
def convert_file(source_path, output_path):
|
| 47 |
with source_path.open("rb") as handle:
|
| 48 |
payload = pickle.load(handle)
|
| 49 |
dict_payload = as_dict(payload)
|
| 50 |
+
json_payload = to_jsonable(dict_payload)
|
| 51 |
+
with output_path.open("w", encoding="utf-8") as handle:
|
| 52 |
+
json.dump(json_payload, handle, indent=2, sort_keys=True)
|
| 53 |
|
| 54 |
|
| 55 |
def build_output_path(source_path, output_dir, suffix):
|
| 56 |
if output_dir is not None:
|
| 57 |
+
return output_dir / f"{source_path.stem}{suffix}.json"
|
| 58 |
+
return source_path.with_name(f"{source_path.stem}{suffix}.json")
|
| 59 |
|
| 60 |
|
| 61 |
def main():
|