EricBoi commited on
Commit
543d14b
·
1 Parent(s): 100bebf

Refactor weight loading in dcm_app.py to use pickle for parameter files, replacing pandas with a dedicated function for loading dictionaries.

Browse files
dcm_app.py CHANGED
@@ -2,7 +2,7 @@ import ase
2
  import jax
3
  import jax.numpy as jnp
4
  import numpy as np
5
- import pandas as pd
6
  from rdkit import Chem
7
  from rdkit.Chem import AllChem
8
  from rdkit.Chem import Draw
@@ -75,9 +75,14 @@ def get_grid_points(coordinates):
75
  return grid_points
76
 
77
 
 
 
 
 
 
78
  def load_weights():
79
- dcm1_weights = pd.read_pickle("wbs/best_0.0_params.pkl")
80
- dcm2_weights = pd.read_pickle("wbs/dcm2-best_1000.0_params.pkl")
81
  return dcm1_weights, dcm2_weights
82
 
83
 
 
2
  import jax
3
  import jax.numpy as jnp
4
  import numpy as np
5
+ import pickle
6
  from rdkit import Chem
7
  from rdkit.Chem import AllChem
8
  from rdkit.Chem import Draw
 
75
  return grid_points
76
 
77
 
78
+ def load_pickle_dict(path):
79
+ with open(path, "rb") as handle:
80
+ return pickle.load(handle)
81
+
82
+
83
  def load_weights():
84
+ dcm1_weights = load_pickle_dict("wbs/best_0.0_params_dict.pkl")
85
+ dcm2_weights = load_pickle_dict("wbs/dcm2-best_1000.0_params_dict.pkl")
86
  return dcm1_weights, dcm2_weights
87
 
88
 
scripts/convert_pkls_to_dicts.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import pickle
4
+ from collections.abc import Mapping
5
+ from pathlib import Path
6
+
7
+
8
+ def as_dict(obj):
9
+ if isinstance(obj, dict):
10
+ return {key: as_dict(value) for key, value in obj.items()}
11
+ if isinstance(obj, Mapping):
12
+ return {key: as_dict(value) for key, value in obj.items()}
13
+ if hasattr(obj, "to_dict"):
14
+ return as_dict(obj.to_dict())
15
+ return obj
16
+
17
+
18
+ def convert_file(source_path, output_path):
19
+ with source_path.open("rb") as handle:
20
+ payload = pickle.load(handle)
21
+ dict_payload = as_dict(payload)
22
+ with output_path.open("wb") as handle:
23
+ pickle.dump(dict_payload, handle, protocol=pickle.HIGHEST_PROTOCOL)
24
+
25
+
26
+ def build_output_path(source_path, output_dir, suffix):
27
+ if output_dir is not None:
28
+ return output_dir / source_path.name
29
+ return source_path.with_name(f"{source_path.stem}{suffix}{source_path.suffix}")
30
+
31
+
32
+ def main():
33
+ parser = argparse.ArgumentParser(
34
+ description="Convert pickled weights into plain dictionaries."
35
+ )
36
+ parser.add_argument(
37
+ "--input-dir",
38
+ default="wbs",
39
+ help="Directory containing .pkl files (default: wbs).",
40
+ )
41
+ parser.add_argument(
42
+ "--output-dir",
43
+ default=None,
44
+ help="Directory to write dictionary pickles. Defaults next to input.",
45
+ )
46
+ parser.add_argument(
47
+ "--suffix",
48
+ default="_dict",
49
+ help="Suffix to append before extension (default: _dict).",
50
+ )
51
+ args = parser.parse_args()
52
+
53
+ input_dir = Path(args.input_dir)
54
+ if not input_dir.is_dir():
55
+ raise SystemExit(f"Input dir not found: {input_dir}")
56
+
57
+ output_dir = Path(args.output_dir) if args.output_dir else None
58
+ if output_dir is not None:
59
+ output_dir.mkdir(parents=True, exist_ok=True)
60
+
61
+ pkls = sorted(input_dir.glob("*.pkl"))
62
+ if not pkls:
63
+ raise SystemExit(f"No .pkl files found in {input_dir}")
64
+
65
+ for source_path in pkls:
66
+ if source_path.stem.endswith(args.suffix):
67
+ continue
68
+ output_path = build_output_path(source_path, output_dir, args.suffix)
69
+ convert_file(source_path, output_path)
70
+ print(f"{source_path} -> {output_path}")
71
+
72
+
73
+ if __name__ == "__main__":
74
+ main()
wbs/best_0.0_params_dict.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:580eff73f9c8bb5e3d67766bf7d923d8855943ccf0b25fb463f0fc0b65e162a1
3
+ size 21694
wbs/best_1000.0_params_dict.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f859c8c20019e87081e319fb03771c22415dca8b33c5a15163b3a2830249e933
3
+ size 22663
wbs/dcm2-best_1000.0_params_dict.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be8f1a6f7648fc4549a06d32a588a7684f03a6f54dfcc785eefcb6c588f32319
3
+ size 22179
wbs/dcm3-best_1000.0_params_dict.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f859c8c20019e87081e319fb03771c22415dca8b33c5a15163b3a2830249e933
3
+ size 22663
wbs/dcm4-best_1000.0_params_dict.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68ac932b26f93df7f99a7c911e5ed111f6818ddb65d32a56b0110e5f82199933
3
+ size 23158