File size: 3,154 Bytes
543d14b
 
0b9d81c
543d14b
 
 
 
0b9d81c
 
 
 
 
 
 
 
 
 
543d14b
 
 
 
 
 
 
 
 
 
 
0b9d81c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543d14b
 
 
 
0b9d81c
 
 
543d14b
 
 
 
0b9d81c
 
543d14b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/usr/bin/env python3
import argparse
import json
import pickle
from collections.abc import Mapping
from pathlib import Path

try:
    import numpy as np
except Exception:  # pragma: no cover - optional dependency for conversion
    np = None

try:
    import jax.numpy as jnp
except Exception:  # pragma: no cover - optional dependency for conversion
    jnp = None


def as_dict(obj):
    """Recursively convert mapping-like objects into plain ``dict`` instances.

    Any ``Mapping`` (``dict`` included) becomes a ``dict`` whose values are
    converted in turn, and objects exposing a ``to_dict()`` method are
    converted through that method. Lists and tuples are walked as well, so
    mappings nested inside sequences are converted too. Lists come back as
    lists and tuples as tuples; namedtuples keep their type. Every other
    value is returned unchanged.

    Args:
        obj: Arbitrary object, typically a freshly unpickled payload.

    Returns:
        The converted structure.
    """
    # ``dict`` is itself a Mapping, so one check covers both cases.
    if isinstance(obj, Mapping):
        return {key: as_dict(value) for key, value in obj.items()}
    # Checked before the sequence branches to keep the original precedence
    # for objects that provide their own dict conversion.
    if hasattr(obj, "to_dict"):
        return as_dict(obj.to_dict())
    if isinstance(obj, list):
        return [as_dict(item) for item in obj]
    if isinstance(obj, tuple):
        items = [as_dict(item) for item in obj]
        # Rebuild namedtuples positionally so their type survives.
        return type(obj)(*items) if hasattr(obj, "_fields") else tuple(items)
    return obj


def to_jsonable(obj):
    """Recursively replace array and scalar leaves with JSON-friendly values.

    ``dict`` values are converted in place of the originals (keys are kept
    as-is). Lists and tuples both become lists. NumPy arrays and JAX arrays
    become nested Python lists, and NumPy integer/float scalars become Python
    numbers. Any other object with a ``tolist`` method is converted through
    it. Everything else is returned untouched.
    """
    if isinstance(obj, dict):
        return {k: to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return list(map(to_jsonable, obj))

    numpy_available = np is not None
    if numpy_available and isinstance(obj, np.ndarray):
        return obj.tolist()
    if numpy_available and isinstance(obj, (np.integer, np.floating)):
        return obj.item()

    if jnp is not None and isinstance(obj, jnp.ndarray):
        # Go through a NumPy copy when NumPy is importable.
        host_array = np.array(obj) if numpy_available else obj
        return host_array.tolist()

    return obj.tolist() if hasattr(obj, "tolist") else obj


def convert_file(source_path, output_path):
    """Convert one pickle file into a pretty-printed JSON file.

    The pickle is loaded, then normalised with ``as_dict`` and
    ``to_jsonable``. The result is written as JSON with two-space
    indentation and sorted keys.

    Args:
        source_path: ``Path`` of the ``.pkl`` file to read.
        output_path: ``Path`` of the JSON file to create or overwrite.

    Raises:
        TypeError: If the converted payload still contains values ``json``
            cannot serialise. In that case the output file is not touched.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only run this converter on pickles from a trusted source.
    with source_path.open("rb") as handle:
        payload = pickle.load(handle)
    json_payload = to_jsonable(as_dict(payload))
    # Serialise completely before opening the output, so an unserialisable
    # value cannot leave a truncated, half-written JSON file behind.
    text = json.dumps(json_payload, indent=2, sort_keys=True)
    with output_path.open("w", encoding="utf-8") as handle:
        handle.write(text)


def build_output_path(source_path, output_dir, suffix):
    """Return the JSON path for *source_path*: ``<stem><suffix>.json``.

    The file goes into *output_dir* when one is given. Otherwise it is placed
    in the same directory as *source_path*.
    """
    file_name = f"{source_path.stem}{suffix}.json"
    if output_dir is None:
        return source_path.with_name(file_name)
    return output_dir / file_name


def main():
    """Command-line entry point: convert every ``*.pkl`` in a directory to JSON.

    Each ``<name>.pkl`` in ``--input-dir`` is written as
    ``<name><suffix>.json``. The file goes next to the input, or into
    ``--output-dir`` when that option is given (the directory is created if
    needed).

    Raises:
        SystemExit: If the input directory does not exist or contains no
            ``.pkl`` files.
    """
    parser = argparse.ArgumentParser(
        description="Convert pickled weights into plain dictionaries."
    )
    parser.add_argument(
        "--input-dir",
        default="wbs",
        help="Directory containing .pkl files (default: wbs).",
    )
    parser.add_argument(
        "--output-dir",
        default=None,
        help="Directory to write the JSON files. Defaults next to input.",
    )
    parser.add_argument(
        "--suffix",
        default="_dict",
        help="Suffix to append before extension (default: _dict).",
    )
    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    if not input_dir.is_dir():
        raise SystemExit(f"Input dir not found: {input_dir}")

    output_dir = Path(args.output_dir) if args.output_dir else None
    if output_dir is not None:
        output_dir.mkdir(parents=True, exist_ok=True)

    pkls = sorted(input_dir.glob("*.pkl"))
    if not pkls:
        raise SystemExit(f"No .pkl files found in {input_dir}")

    for source_path in pkls:
        # Skip pickles whose stem already carries the suffix (presumably
        # dictionary pickles left by an earlier version of this tool).
        # Guard against an empty suffix: every string ends with "", which
        # would otherwise skip every input file.
        if args.suffix and source_path.stem.endswith(args.suffix):
            continue
        output_path = build_output_path(source_path, output_dir, args.suffix)
        convert_file(source_path, output_path)
        print(f"{source_path} -> {output_path}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()