EricBoi commited on
Commit
0b9d81c
·
1 Parent(s): 543d14b

Refactor weight loading in dcm_app.py to use JSON instead of pickle, adding a new function for restoring arrays from loaded data.

Browse files
Files changed (2) hide show
  1. dcm_app.py +21 -6
  2. scripts/convert_pkls_to_dicts.py +33 -4
dcm_app.py CHANGED
@@ -2,7 +2,7 @@ import ase
2
  import jax
3
  import jax.numpy as jnp
4
  import numpy as np
5
- import pickle
6
  from rdkit import Chem
7
  from rdkit.Chem import AllChem
8
  from rdkit.Chem import Draw
@@ -75,14 +75,29 @@ def get_grid_points(coordinates):
75
  return grid_points
76
 
77
 
78
- def load_pickle_dict(path):
79
- with open(path, "rb") as handle:
80
- return pickle.load(handle)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
 
83
  def load_weights():
84
- dcm1_weights = load_pickle_dict("wbs/best_0.0_params_dict.pkl")
85
- dcm2_weights = load_pickle_dict("wbs/dcm2-best_1000.0_params_dict.pkl")
86
  return dcm1_weights, dcm2_weights
87
 
88
 
 
2
  import jax
3
  import jax.numpy as jnp
4
  import numpy as np
5
+ import json
6
  from rdkit import Chem
7
  from rdkit.Chem import AllChem
8
  from rdkit.Chem import Draw
 
75
  return grid_points
76
 
77
 
78
+ def restore_arrays(obj):
79
+ if isinstance(obj, dict):
80
+ return {key: restore_arrays(value) for key, value in obj.items()}
81
+ if isinstance(obj, list):
82
+ restored = [restore_arrays(value) for value in obj]
83
+ if any(isinstance(value, dict) for value in restored):
84
+ return restored
85
+ try:
86
+ return np.array(restored)
87
+ except Exception:
88
+ return restored
89
+ return obj
90
+
91
+
92
+ def load_json_dict(path):
93
+ with open(path, "r", encoding="utf-8") as handle:
94
+ payload = json.load(handle)
95
+ return restore_arrays(payload)
96
 
97
 
98
  def load_weights():
99
+ dcm1_weights = load_json_dict("wbs/best_0.0_params.json")
100
+ dcm2_weights = load_json_dict("wbs/dcm2-best_1000.0_params.json")
101
  return dcm1_weights, dcm2_weights
102
 
103
 
scripts/convert_pkls_to_dicts.py CHANGED
@@ -1,9 +1,20 @@
1
  #!/usr/bin/env python3
2
  import argparse
 
3
  import pickle
4
  from collections.abc import Mapping
5
  from pathlib import Path
6
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  def as_dict(obj):
9
  if isinstance(obj, dict):
@@ -15,18 +26,36 @@ def as_dict(obj):
15
  return obj
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def convert_file(source_path, output_path):
19
  with source_path.open("rb") as handle:
20
  payload = pickle.load(handle)
21
  dict_payload = as_dict(payload)
22
- with output_path.open("wb") as handle:
23
- pickle.dump(dict_payload, handle, protocol=pickle.HIGHEST_PROTOCOL)
 
24
 
25
 
26
  def build_output_path(source_path, output_dir, suffix):
27
  if output_dir is not None:
28
- return output_dir / source_path.name
29
- return source_path.with_name(f"{source_path.stem}{suffix}{source_path.suffix}")
30
 
31
 
32
  def main():
 
1
  #!/usr/bin/env python3
2
  import argparse
3
+ import json
4
  import pickle
5
  from collections.abc import Mapping
6
  from pathlib import Path
7
 
8
+ try:
9
+ import numpy as np
10
+ except Exception: # pragma: no cover - optional dependency for conversion
11
+ np = None
12
+
13
+ try:
14
+ import jax.numpy as jnp
15
+ except Exception: # pragma: no cover - optional dependency for conversion
16
+ jnp = None
17
+
18
 
19
  def as_dict(obj):
20
  if isinstance(obj, dict):
 
26
  return obj
27
 
28
 
29
+ def to_jsonable(obj):
30
+ if isinstance(obj, dict):
31
+ return {key: to_jsonable(value) for key, value in obj.items()}
32
+ if isinstance(obj, (list, tuple)):
33
+ return [to_jsonable(value) for value in obj]
34
+ if np is not None:
35
+ if isinstance(obj, np.ndarray):
36
+ return obj.tolist()
37
+ if isinstance(obj, (np.integer, np.floating)):
38
+ return obj.item()
39
+ if jnp is not None and isinstance(obj, jnp.ndarray):
40
+ return np.array(obj).tolist() if np is not None else obj.tolist()
41
+ if hasattr(obj, "tolist"):
42
+ return obj.tolist()
43
+ return obj
44
+
45
+
46
  def convert_file(source_path, output_path):
47
  with source_path.open("rb") as handle:
48
  payload = pickle.load(handle)
49
  dict_payload = as_dict(payload)
50
+ json_payload = to_jsonable(dict_payload)
51
+ with output_path.open("w", encoding="utf-8") as handle:
52
+ json.dump(json_payload, handle, indent=2, sort_keys=True)
53
 
54
 
55
  def build_output_path(source_path, output_dir, suffix):
56
  if output_dir is not None:
57
+ return output_dir / f"{source_path.stem}{suffix}.json"
58
+ return source_path.with_name(f"{source_path.stem}{suffix}.json")
59
 
60
 
61
  def main():