| | |
| | import argparse |
| | import json |
| | import pickle |
| | from collections.abc import Mapping |
| | from pathlib import Path |
| |
|
| | try: |
| | import numpy as np |
| | except Exception: |
| | np = None |
| |
|
| | try: |
| | import jax.numpy as jnp |
| | except Exception: |
| | jnp = None |
| |
|
| |
|
def as_dict(obj):
    """Recursively replace mapping-like containers with plain ``dict`` objects.

    - Any ``collections.abc.Mapping`` (``dict`` included) becomes a ``dict``
      whose values are converted recursively; keys are kept unchanged.
    - Objects exposing ``to_dict()`` are converted through that method, and
      the result is processed recursively.
    - Lists and tuples are traversed so that mappings / ``to_dict`` objects
      nested inside them are converted too. Lists stay lists; tuples
      (including tuple subclasses) become plain tuples.
    - Every other value is returned unchanged.

    Args:
        obj: Arbitrary (possibly nested) object, e.g. an unpickled payload.

    Returns:
        The converted object.
    """
    # A single Mapping check covers dict as well as other mapping types.
    if isinstance(obj, Mapping):
        return {key: as_dict(value) for key, value in obj.items()}
    # Checked before the sequence branches so that objects which are both
    # sequences and provide to_dict() keep taking the to_dict() path.
    if hasattr(obj, "to_dict"):
        return as_dict(obj.to_dict())
    # Descend into sequences; previously mappings nested in lists/tuples
    # were left untouched and later failed JSON serialization.
    if isinstance(obj, list):
        return [as_dict(value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(as_dict(value) for value in obj)
    return obj
| |
|
| |
|
def to_jsonable(obj):
    """Recursively convert *obj* into values the ``json`` module can encode.

    Conversion rules, applied in order:

    - Any ``collections.abc.Mapping`` (``dict`` included) becomes a ``dict``
      with the same keys and recursively converted values.
    - Lists and tuples become lists of converted items.
    - When NumPy is importable: ``np.ndarray`` becomes a nested list via
      ``tolist()``, and NumPy integer/floating scalars become Python numbers.
    - When JAX is importable: JAX arrays become nested lists.
    - Any other object with a ``tolist()`` method is converted through it.
    - Everything else is returned unchanged.

    Keys are not converted; ``json`` decides whether they are acceptable.
    Results of ``tolist()`` are not recursed into further.

    Args:
        obj: Arbitrary (possibly nested) object.

    Returns:
        The converted object.
    """
    # Mapping, not just dict: non-dict mappings that were not normalised
    # earlier would otherwise be returned as-is and rejected by json.dump.
    if isinstance(obj, Mapping):
        return {key: to_jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(value) for value in obj]
    if np is not None:
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, (np.integer, np.floating)):
            return obj.item()
    if jnp is not None and isinstance(obj, jnp.ndarray):
        # Materialise on the host through NumPy before listing.
        return np.array(obj).tolist() if np is not None else obj.tolist()
    # Generic fallback for other array-likes exposing tolist().
    if hasattr(obj, "tolist"):
        return obj.tolist()
    return obj
| |
|
| |
|
def convert_file(source_path, output_path):
    """Load one pickle file and write its contents as pretty-printed JSON.

    The unpickled payload is normalised with ``as_dict`` and then made
    JSON-compatible with ``to_jsonable``. It is written with 2-space
    indentation and sorted keys. ``json``'s default ``allow_nan=True`` is
    kept, so NaN/Infinity values are emitted as the non-standard tokens
    ``NaN`` / ``Infinity``.

    The JSON is first written to a sibling ``<name>.tmp`` file and moved over
    ``output_path`` only after serialization succeeds. A failure therefore
    never leaves a truncated file at ``output_path`` or clobbers an existing
    good output.

    Args:
        source_path: ``pathlib.Path`` of the pickle file to read.
        output_path: ``pathlib.Path`` of the JSON file to create/overwrite.

    Raises:
        Exceptions from unpickling, conversion or ``json.dump`` (e.g.
        ``TypeError`` for values JSON cannot represent) propagate unchanged.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only run this on pickle files from a trusted source.
    with source_path.open("rb") as handle:
        payload = pickle.load(handle)
    json_payload = to_jsonable(as_dict(payload))

    tmp_path = output_path.with_name(output_path.name + ".tmp")
    try:
        with tmp_path.open("w", encoding="utf-8") as handle:
            json.dump(json_payload, handle, indent=2, sort_keys=True)
        # Same directory => same filesystem, so this rename overwrites the
        # destination in one step (os.replace semantics).
        tmp_path.replace(output_path)
    except BaseException:
        # Remove the partial temp file, then let the error propagate.
        tmp_path.unlink(missing_ok=True)
        raise
| |
|
| |
|
def build_output_path(source_path, output_dir, suffix):
    """Return the JSON destination path for *source_path*.

    The file name is ``<stem><suffix>.json``. It is placed in *output_dir*
    when one is given; otherwise it sits next to *source_path*.

    Args:
        source_path: ``pathlib.Path`` of the input file.
        output_dir: Destination directory as a ``pathlib.Path``, or ``None``.
        suffix: Text inserted between the stem and the ``.json`` extension.

    Returns:
        ``pathlib.Path`` of the output file.
    """
    file_name = f"{source_path.stem}{suffix}.json"
    if output_dir is None:
        return source_path.with_name(file_name)
    return output_dir / file_name
| |
|
| |
|
def main():
    """Command-line entry point: convert every ``*.pkl`` in a directory to JSON.

    For each ``*.pkl`` directly inside ``--input-dir`` (processed in sorted
    path order), writes ``<stem><suffix>.json``. The output goes inside
    ``--output-dir`` (created if missing) when given, otherwise next to the
    input file. Each conversion is reported as ``source -> output``.

    Raises:
        SystemExit: If the input directory does not exist or contains no
            ``.pkl`` files.
    """
    parser = argparse.ArgumentParser(
        description="Convert pickled weights into plain dictionaries."
    )
    parser.add_argument(
        "--input-dir",
        default="wbs",
        help="Directory containing .pkl files (default: wbs).",
    )
    parser.add_argument(
        "--output-dir",
        default=None,
        # The script writes JSON, not pickles; the help text now says so.
        help="Directory to write dictionary JSON files. Defaults next to input.",
    )
    parser.add_argument(
        "--suffix",
        default="_dict",
        help="Suffix to append before extension (default: _dict).",
    )
    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    if not input_dir.is_dir():
        raise SystemExit(f"Input dir not found: {input_dir}")

    output_dir = Path(args.output_dir) if args.output_dir else None
    if output_dir is not None:
        output_dir.mkdir(parents=True, exist_ok=True)

    pkls = sorted(input_dir.glob("*.pkl"))
    if not pkls:
        raise SystemExit(f"No .pkl files found in {input_dir}")

    for source_path in pkls:
        # Skip inputs whose stem already ends with the suffix (presumably
        # outputs of an earlier pickle-writing version -- verify). An empty
        # suffix must not trigger this: every string ends with "", so the
        # unguarded check used to skip every input silently.
        if args.suffix and source_path.stem.endswith(args.suffix):
            continue
        output_path = build_output_path(source_path, output_dir, args.suffix)
        convert_file(source_path, output_path)
        print(f"{source_path} -> {output_path}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
| |
|