#!/usr/bin/env python3
"""Convert pickled weight files (``*.pkl``) into JSON files of plain dicts.

Each ``<stem>.pkl`` in the input directory is unpickled, converted into plain
Python containers and scalars, and written as ``<stem><suffix>.json``.
"""
import argparse
import json
import pickle
from collections.abc import Mapping
from pathlib import Path

try:
    import numpy as np
except Exception:  # pragma: no cover - optional dependency for conversion
    np = None

try:
    import jax.numpy as jnp
except Exception:  # pragma: no cover - optional dependency for conversion
    jnp = None


def as_dict(obj):
    """Recursively turn mapping-like objects into plain ``dict`` instances.

    Any ``Mapping`` (``dict`` included) is rebuilt as a ``dict`` with its
    values converted recursively. Lists and tuples are traversed so that
    mappings nested inside them are converted too. Objects exposing a
    ``to_dict()`` method are converted through it. Anything else is returned
    unchanged.
    """
    if isinstance(obj, Mapping):
        return {key: as_dict(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [as_dict(value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(as_dict(value) for value in obj)
    if hasattr(obj, "to_dict"):
        return as_dict(obj.to_dict())
    return obj


def to_jsonable(obj):
    """Recursively convert *obj* into values the ``json`` module can encode.

    Mappings are rebuilt as dicts with converted values; lists and tuples
    become lists. NumPy and JAX arrays become nested lists, NumPy integer and
    floating scalars become Python scalars, and any other object with a
    ``tolist()`` method is converted through it. Remaining values are
    returned unchanged (and may still be rejected by ``json``).
    """
    if isinstance(obj, Mapping):
        return {key: to_jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(value) for value in obj]
    if np is not None:
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, (np.integer, np.floating)):
            return obj.item()
    if jnp is not None and isinstance(obj, jnp.ndarray):
        # Route JAX arrays through NumPy when it is available.
        return np.array(obj).tolist() if np is not None else obj.tolist()
    if hasattr(obj, "tolist"):
        return obj.tolist()
    return obj


def convert_file(source_path, output_path):
    """Unpickle *source_path* and write its contents as JSON to *output_path*.

    The JSON text is fully encoded before the output file is touched, so a
    payload that cannot be serialized raises without leaving a truncated
    file behind.

    Args:
        source_path: ``Path`` of the pickle file to read.
        output_path: ``Path`` of the JSON file to (over)write.

    Raises:
        TypeError: if the converted payload holds a value ``json`` cannot
            encode, or dict keys that cannot be sorted against each other.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only run this converter on pickles from a trusted source.
    with source_path.open("rb") as handle:
        payload = pickle.load(handle)
    json_payload = to_jsonable(as_dict(payload))
    text = json.dumps(json_payload, indent=2, sort_keys=True)
    output_path.write_text(text, encoding="utf-8")


def build_output_path(source_path, output_dir, suffix):
    """Return the JSON path for *source_path*: ``<stem><suffix>.json``.

    The file is placed in *output_dir* when given, otherwise next to the
    source file.
    """
    if output_dir is not None:
        return output_dir / f"{source_path.stem}{suffix}.json"
    return source_path.with_name(f"{source_path.stem}{suffix}.json")


def main():
    """Parse CLI arguments and convert every ``*.pkl`` in the input dir.

    Raises:
        SystemExit: if the input directory is missing or has no ``.pkl``
            files.
    """
    parser = argparse.ArgumentParser(
        description="Convert pickled weights into JSON files of plain dictionaries."
    )
    parser.add_argument(
        "--input-dir",
        default="wbs",
        help="Directory containing .pkl files (default: wbs).",
    )
    parser.add_argument(
        "--output-dir",
        default=None,
        help="Directory to write JSON files. Defaults next to input.",
    )
    parser.add_argument(
        "--suffix",
        default="_dict",
        help="Suffix to append before extension (default: _dict).",
    )
    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    if not input_dir.is_dir():
        raise SystemExit(f"Input dir not found: {input_dir}")

    output_dir = Path(args.output_dir) if args.output_dir else None
    if output_dir is not None:
        output_dir.mkdir(parents=True, exist_ok=True)

    pkls = sorted(input_dir.glob("*.pkl"))
    if not pkls:
        raise SystemExit(f"No .pkl files found in {input_dir}")

    for source_path in pkls:
        # Skip inputs whose stem already carries the suffix (e.g. an earlier
        # "<stem>_dict.pkl" output). An empty suffix would match every stem
        # and skip all files, so only apply the check when one is set.
        if args.suffix and source_path.stem.endswith(args.suffix):
            continue
        output_path = build_output_path(source_path, output_dir, args.suffix)
        convert_file(source_path, output_path)
        print(f"{source_path} -> {output_path}")


if __name__ == "__main__":
    main()