DCMNet / scripts /convert_pkls_to_dicts.py
EricBoi's picture
Refactor weight loading in dcm_app.py to use JSON instead of pickle, adding a new function for restoring arrays from loaded data.
0b9d81c
#!/usr/bin/env python3
import argparse
import json
import pickle
from collections.abc import Mapping
from pathlib import Path
try:
import numpy as np
except Exception: # pragma: no cover - optional dependency for conversion
np = None
try:
import jax.numpy as jnp
except Exception: # pragma: no cover - optional dependency for conversion
jnp = None
def as_dict(obj):
    """Recursively convert mapping-like containers into plain ``dict`` objects.

    Handles three kinds of input:

    * any ``Mapping`` (including ``dict``) becomes a new ``dict`` whose values
      are converted recursively;
    * an object exposing ``to_dict()`` is replaced by the converted result of
      that call;
    * a ``list`` or ``tuple`` is rebuilt with each element converted, so
      mappings stored inside sequences are converted as well.

    Anything else is returned unchanged.

    Args:
        obj: Arbitrary (possibly nested) object.

    Returns:
        The converted structure; leaves that are not containers are the
        original objects.
    """
    # ``dict`` is itself a ``Mapping``, so one check covers both cases.
    if isinstance(obj, Mapping):
        return {key: as_dict(value) for key, value in obj.items()}
    # Checked before the sequence branch so that an object providing
    # ``to_dict`` keeps taking this path even if it subclasses list/tuple.
    if hasattr(obj, "to_dict"):
        return as_dict(obj.to_dict())
    if isinstance(obj, list):
        return [as_dict(value) for value in obj]
    if isinstance(obj, tuple):
        converted = [as_dict(value) for value in obj]
        # Namedtuple classes take their fields as positional arguments,
        # unlike ``tuple`` which takes a single iterable.
        if hasattr(obj, "_fields"):
            return type(obj)(*converted)
        return tuple(converted)
    return obj
def to_jsonable(obj):
    """Recursively convert *obj* into a structure made of JSON-friendly types.

    Conversion rules, applied in order:

    * any ``Mapping`` (not only ``dict``) becomes a ``dict`` with converted
      values; keys are left as they are;
    * ``list`` and ``tuple`` become a ``list`` of converted elements;
    * NumPy arrays become nested lists, and NumPy integer/floating scalars
      become Python numbers (only when NumPy could be imported);
    * ``jnp.ndarray`` instances become nested lists (only when JAX could be
      imported);
    * any other object with a ``tolist()`` method is replaced by the result
      of that call;
    * everything else is returned unchanged.

    Args:
        obj: Arbitrary (possibly nested) object.

    Returns:
        The converted structure. Values that match none of the rules are
        passed through untouched, so the result is not guaranteed to be
        serializable by ``json`` for arbitrary input.
    """
    # Accept every Mapping so that dict-like containers nested inside lists
    # or tuples are expanded instead of being passed through unconverted.
    if isinstance(obj, Mapping):
        return {key: to_jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(value) for value in obj]
    # ``np`` / ``jnp`` are None when the optional import failed.
    if np is not None:
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, (np.integer, np.floating)):
            return obj.item()
    if jnp is not None and isinstance(obj, jnp.ndarray):
        # Go through NumPy when it is available; otherwise rely on the
        # array's own ``tolist``.
        return np.array(obj).tolist() if np is not None else obj.tolist()
    if hasattr(obj, "tolist"):
        return obj.tolist()
    return obj
def convert_file(source_path, output_path):
    """Read one pickle file and write its contents out as JSON.

    The unpickled object is passed through ``as_dict`` and then
    ``to_jsonable`` before being dumped with two-space indentation and
    sorted keys.

    Args:
        source_path: Path-like object with an ``open`` method pointing at the
            pickle file to read.
        output_path: Path-like object with an ``open`` method pointing at the
            JSON file to create or overwrite.
    """
    # NOTE(security): unpickling can execute arbitrary code. Only run this
    # on pickle files that come from a trusted source.
    with source_path.open("rb") as pickle_file:
        loaded = pickle.load(pickle_file)

    serializable = to_jsonable(as_dict(loaded))

    with output_path.open("w", encoding="utf-8") as json_file:
        json.dump(serializable, json_file, indent=2, sort_keys=True)
def build_output_path(source_path, output_dir, suffix):
    """Return the JSON path that corresponds to *source_path*.

    The file name is the source stem followed by *suffix* and ``.json``.

    Args:
        source_path: ``Path`` of the input file.
        output_dir: ``Path`` of the target directory, or ``None`` to place
            the result in the same directory as *source_path*.
        suffix: Text inserted between the stem and the ``.json`` extension.

    Returns:
        The ``Path`` of the JSON file to write.
    """
    json_name = f"{source_path.stem}{suffix}.json"
    if output_dir is None:
        return source_path.with_name(json_name)
    return output_dir / json_name
def main(argv=None):
    """Convert every ``.pkl`` file in a directory into a JSON file.

    Command-line options:

    * ``--input-dir``: directory scanned (non-recursively) for ``*.pkl``.
    * ``--output-dir``: directory for the JSON files; it is created if
      missing. When omitted, each JSON file is written next to its input.
    * ``--suffix``: text appended to the input stem before ``.json``.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read the process command line.

    Raises:
        SystemExit: If the input directory does not exist or contains no
            ``.pkl`` files (argparse also exits on invalid arguments).
    """
    parser = argparse.ArgumentParser(
        description="Convert pickled weights into plain dictionaries."
    )
    parser.add_argument(
        "--input-dir",
        default="wbs",
        help="Directory containing .pkl files (default: wbs).",
    )
    parser.add_argument(
        "--output-dir",
        default=None,
        help="Directory to write JSON files. Defaults next to input.",
    )
    parser.add_argument(
        "--suffix",
        default="_dict",
        help="Suffix to append before extension (default: _dict).",
    )
    args = parser.parse_args(argv)

    input_dir = Path(args.input_dir)
    if not input_dir.is_dir():
        raise SystemExit(f"Input dir not found: {input_dir}")

    output_dir = Path(args.output_dir) if args.output_dir else None
    if output_dir is not None:
        output_dir.mkdir(parents=True, exist_ok=True)

    # Sorted so the files are processed (and reported) in a stable order.
    pkls = sorted(input_dir.glob("*.pkl"))
    if not pkls:
        raise SystemExit(f"No .pkl files found in {input_dir}")

    for source_path in pkls:
        # Skip pickles whose stem already carries the suffix. The suffix must
        # be non-empty for this test: ``str.endswith("")`` is always true, so
        # an empty suffix would otherwise skip every input silently.
        if args.suffix and source_path.stem.endswith(args.suffix):
            continue
        output_path = build_output_path(source_path, output_dir, args.suffix)
        convert_file(source_path, output_path)
        print(f"{source_path} -> {output_path}")
# Run the conversion only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()