Refactor weight loading in dcm_app.py to use pickle for parameter files, replacing pandas with a dedicated function for loading dictionaries.
Browse files
dcm_app.py
CHANGED
|
@@ -2,7 +2,7 @@ import ase
|
|
| 2 |
import jax
|
| 3 |
import jax.numpy as jnp
|
| 4 |
import numpy as np
|
| 5 |
-
import
|
| 6 |
from rdkit import Chem
|
| 7 |
from rdkit.Chem import AllChem
|
| 8 |
from rdkit.Chem import Draw
|
|
@@ -75,9 +75,14 @@ def get_grid_points(coordinates):
|
|
| 75 |
return grid_points
|
| 76 |
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
def load_weights():
|
| 79 |
-
dcm1_weights =
|
| 80 |
-
dcm2_weights =
|
| 81 |
return dcm1_weights, dcm2_weights
|
| 82 |
|
| 83 |
|
|
|
|
| 2 |
import jax
|
| 3 |
import jax.numpy as jnp
|
| 4 |
import numpy as np
|
| 5 |
+
import pickle
|
| 6 |
from rdkit import Chem
|
| 7 |
from rdkit.Chem import AllChem
|
| 8 |
from rdkit.Chem import Draw
|
|
|
|
| 75 |
return grid_points
|
| 76 |
|
| 77 |
|
| 78 |
+
def load_pickle_dict(path):
|
| 79 |
+
with open(path, "rb") as handle:
|
| 80 |
+
return pickle.load(handle)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
def load_weights():
|
| 84 |
+
dcm1_weights = load_pickle_dict("wbs/best_0.0_params_dict.pkl")
|
| 85 |
+
dcm2_weights = load_pickle_dict("wbs/dcm2-best_1000.0_params_dict.pkl")
|
| 86 |
return dcm1_weights, dcm2_weights
|
| 87 |
|
| 88 |
|
scripts/convert_pkls_to_dicts.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import argparse
|
| 3 |
+
import pickle
|
| 4 |
+
from collections.abc import Mapping
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def as_dict(obj):
|
| 9 |
+
if isinstance(obj, dict):
|
| 10 |
+
return {key: as_dict(value) for key, value in obj.items()}
|
| 11 |
+
if isinstance(obj, Mapping):
|
| 12 |
+
return {key: as_dict(value) for key, value in obj.items()}
|
| 13 |
+
if hasattr(obj, "to_dict"):
|
| 14 |
+
return as_dict(obj.to_dict())
|
| 15 |
+
return obj
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def convert_file(source_path, output_path):
|
| 19 |
+
with source_path.open("rb") as handle:
|
| 20 |
+
payload = pickle.load(handle)
|
| 21 |
+
dict_payload = as_dict(payload)
|
| 22 |
+
with output_path.open("wb") as handle:
|
| 23 |
+
pickle.dump(dict_payload, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def build_output_path(source_path, output_dir, suffix):
|
| 27 |
+
if output_dir is not None:
|
| 28 |
+
return output_dir / source_path.name
|
| 29 |
+
return source_path.with_name(f"{source_path.stem}{suffix}{source_path.suffix}")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def main():
|
| 33 |
+
parser = argparse.ArgumentParser(
|
| 34 |
+
description="Convert pickled weights into plain dictionaries."
|
| 35 |
+
)
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
"--input-dir",
|
| 38 |
+
default="wbs",
|
| 39 |
+
help="Directory containing .pkl files (default: wbs).",
|
| 40 |
+
)
|
| 41 |
+
parser.add_argument(
|
| 42 |
+
"--output-dir",
|
| 43 |
+
default=None,
|
| 44 |
+
help="Directory to write dictionary pickles. Defaults next to input.",
|
| 45 |
+
)
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--suffix",
|
| 48 |
+
default="_dict",
|
| 49 |
+
help="Suffix to append before extension (default: _dict).",
|
| 50 |
+
)
|
| 51 |
+
args = parser.parse_args()
|
| 52 |
+
|
| 53 |
+
input_dir = Path(args.input_dir)
|
| 54 |
+
if not input_dir.is_dir():
|
| 55 |
+
raise SystemExit(f"Input dir not found: {input_dir}")
|
| 56 |
+
|
| 57 |
+
output_dir = Path(args.output_dir) if args.output_dir else None
|
| 58 |
+
if output_dir is not None:
|
| 59 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 60 |
+
|
| 61 |
+
pkls = sorted(input_dir.glob("*.pkl"))
|
| 62 |
+
if not pkls:
|
| 63 |
+
raise SystemExit(f"No .pkl files found in {input_dir}")
|
| 64 |
+
|
| 65 |
+
for source_path in pkls:
|
| 66 |
+
if source_path.stem.endswith(args.suffix):
|
| 67 |
+
continue
|
| 68 |
+
output_path = build_output_path(source_path, output_dir, args.suffix)
|
| 69 |
+
convert_file(source_path, output_path)
|
| 70 |
+
print(f"{source_path} -> {output_path}")
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
if __name__ == "__main__":
|
| 74 |
+
main()
|
wbs/best_0.0_params_dict.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:580eff73f9c8bb5e3d67766bf7d923d8855943ccf0b25fb463f0fc0b65e162a1
|
| 3 |
+
size 21694
|
wbs/best_1000.0_params_dict.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f859c8c20019e87081e319fb03771c22415dca8b33c5a15163b3a2830249e933
|
| 3 |
+
size 22663
|
wbs/dcm2-best_1000.0_params_dict.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:be8f1a6f7648fc4549a06d32a588a7684f03a6f54dfcc785eefcb6c588f32319
|
| 3 |
+
size 22179
|
wbs/dcm3-best_1000.0_params_dict.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f859c8c20019e87081e319fb03771c22415dca8b33c5a15163b3a2830249e933
|
| 3 |
+
size 22663
|
wbs/dcm4-best_1000.0_params_dict.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:68ac932b26f93df7f99a7c911e5ed111f6818ddb65d32a56b0110e5f82199933
|
| 3 |
+
size 23158
|