File size: 4,411 Bytes
9af1f06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#!/usr/bin/env python3
"""
bin_to_safetensors.py

Convert a PyTorch checkpoint (e.g., pytorch_model.bin / .pt / .ckpt) to a .safetensors file.
- Safe tensors only: tensors are saved; non-tensor Python objects (optimizer, schedulers, etc.) are ignored.
- Heuristics try to locate a model state_dict within common training checkpoints.

USAGE:
  python bin_to_safetensors.py --in pytorch_model.bin --out model.safetensors
  python bin_to_safetensors.py --in trainer.ckpt --out model.safetensors

NOTE:
  Loading with torch.load uses pickle and can execute code from untrusted sources.
  Only run this on checkpoints from sources you trust.
"""

import argparse
import sys
from typing import Any, Dict, Optional

import torch
from safetensors.torch import save_file

# NOTE(review): `is_safe_tensor` is not part of the public safetensors.torch
# API; importing it unconditionally crashes this module at import time.
# Import it defensively and degrade to a no-op stub when absent so the
# (rarely taken) bytes check below still resolves.
try:
    from safetensors.torch import is_safe_tensor  # type: ignore[attr-defined]
except ImportError:
    def is_safe_tensor(_data: Any) -> bool:
        return False


def _is_tensor_dict(d: Any) -> bool:
    if not isinstance(d, dict) or not d:
        return False
    # Determine if all values are (Tensor | ShardedTensor-like)
    for v in d.values():
        if not (torch.is_tensor(v) or (hasattr(v, "tensor") and torch.is_tensor(getattr(v, "tensor")))):
            return False
    return True


def _cpu_contiguous(v: Any) -> torch.Tensor:
    """Unwrap a ShardedTensor-like wrapper if needed, then return the tensor
    detached, moved to CPU, and made contiguous (as required for saving)."""
    t = v if torch.is_tensor(v) else v.tensor
    return t.detach().cpu().contiguous()


def _normalize_tensor_dict(d: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """Apply :func:`_cpu_contiguous` to every value of a tensor dict."""
    return {k: _cpu_contiguous(v) for k, v in d.items()}


def _extract_state_dict(obj: Any) -> Dict[str, torch.Tensor]:
    """
    Try to extract a {name: tensor} dict from various checkpoint formats.

    Search order:
      1. *obj* itself is already a tensor dict (typical HF "pytorch_model.bin").
      2. A well-known container key ("state_dict", "model_state_dict", ...).
      3. Any value of *obj* that is itself a tensor dict (Lightning-style
         checkpoints sometimes nest the weights under another key).

    Raises:
        ValueError: if no dict of tensors can be located.
    """
    # Case 1: already a tensor dict.
    if _is_tensor_dict(obj):
        return _normalize_tensor_dict(obj)

    if isinstance(obj, dict):
        # Case 2: common container keys used by training frameworks.
        candidate_keys = (
            "state_dict",
            "model_state_dict",
            "model",
            "module",          # sometimes raw module.state_dict() is stored here
            "network",
            "net",
            "weights",
        )
        for ck in candidate_keys:
            if ck in obj and _is_tensor_dict(obj[ck]):
                return _normalize_tensor_dict(obj[ck])

        # Case 3: nested under some other key — only the values matter here.
        for v in obj.values():
            if _is_tensor_dict(v):
                return _normalize_tensor_dict(v)

    raise ValueError(
        "Could not find a model state_dict (a dict of tensors). "
        "If this is a full training checkpoint, load it in Python, extract model.state_dict(), "
        "and save that mapping instead."
    )


def convert_bin_to_safetensors(in_path: str, out_path: str,
                               metadata: Optional[Dict[str, str]] = None) -> None:
    """Load a PyTorch checkpoint from *in_path* and write its tensors to
    *out_path* in safetensors format.

    Args:
        in_path: Path to a .bin/.pt/.ckpt pickle checkpoint.
        out_path: Destination .safetensors path.
        metadata: Optional extra key/value pairs stored in the file header
            (keys and values are stringified; safetensors requires str->str).

    Raises:
        ValueError: if no tensor state_dict can be found in the checkpoint.
    """
    # Bail out politely if the input already is a safetensors file — doing
    # this BEFORE torch.load, which would otherwise fail with a confusing
    # unpickling error.  A safetensors file begins with an 8-byte
    # little-endian header length followed by a JSON header, so the 9th
    # byte is '{'.
    with open(in_path, "rb") as f:
        head = f.read(9)
    if len(head) == 9 and head[8:9] == b"{":
        print(f"Input appears to already be a safetensors file: {in_path}")
        return

    # TRUST WARNING: torch.load uses pickle. Only load from trusted files.
    obj = torch.load(in_path, map_location="cpu")

    state = _extract_state_dict(obj)

    # Optional basic metadata
    meta = {"format": "converted-from-pytorch-bin"}
    if metadata:
        meta.update({str(k): str(v) for k, v in metadata.items()})

    # Save
    save_file(state, out_path, metadata=meta)
    print(f"✅ Wrote {out_path} with {len(state)} tensors.")


def main(argv=None):
    """Command-line entry point: parse arguments and run the conversion."""
    ap = argparse.ArgumentParser(description="Convert PyTorch .bin/.pt/.ckpt to .safetensors")
    ap.add_argument("--in", dest="in_path", required=True, help="Input .bin/.pt/.ckpt file path")
    ap.add_argument("--out", dest="out_path", required=True, help="Output .safetensors file path")
    ap.add_argument("--meta", nargs="*", default=[], help='Optional metadata entries like key=value (repeatable)')
    ns = ap.parse_args(argv)

    # Collect key=value metadata pairs; warn (on stderr) about anything else.
    metadata = {}
    for entry in ns.meta:
        key, sep, value = entry.partition("=")
        if sep:
            metadata[key] = value
        else:
            print(f"Warning: ignoring malformed --meta entry (expected key=value): {entry}", file=sys.stderr)

    convert_bin_to_safetensors(ns.in_path, ns.out_path, metadata)


# Script entry point: only run the CLI when executed directly, so the
# module can also be imported for its conversion helpers.
if __name__ == "__main__":
    main()