|
|
|
|
|
""" |
|
|
bin_to_safetensors.py |
|
|
|
|
|
Convert a PyTorch checkpoint (e.g., pytorch_model.bin / .pt / .ckpt) to a .safetensors file. |
|
|
- Safe tensors only: tensors are saved; non-tensor Python objects (optimizer, schedulers, etc.) are ignored. |
|
|
- Heuristics try to locate a model state_dict within common training checkpoints. |
|
|
|
|
|
USAGE: |
|
|
python bin_to_safetensors.py --in pytorch_model.bin --out model.safetensors |
|
|
python bin_to_safetensors.py --in trainer.ckpt --out model.safetensors |
|
|
|
|
|
NOTE: |
|
|
Loading with torch.load uses pickle and can execute code from untrusted sources. |
|
|
Only run this on checkpoints from sources you trust. |
|
|
""" |
|
|
|
|
|
import argparse
import sys
from typing import Any, Dict, Optional

import torch
from safetensors.torch import save_file
|
|
|
|
|
|
|
|
def _is_tensor_dict(d: Any) -> bool: |
|
|
if not isinstance(d, dict) or not d: |
|
|
return False |
|
|
|
|
|
for v in d.values(): |
|
|
if not (torch.is_tensor(v) or (hasattr(v, "tensor") and torch.is_tensor(getattr(v, "tensor")))): |
|
|
return False |
|
|
return True |
|
|
|
|
|
|
|
|
def _extract_state_dict(obj: Any) -> Dict[str, torch.Tensor]: |
|
|
""" |
|
|
Try to extract a {name: tensor} dict from various checkpoint formats. |
|
|
""" |
|
|
|
|
|
if _is_tensor_dict(obj): |
|
|
|
|
|
return {k: (v.detach().cpu().contiguous() if torch.is_tensor(v) else v.tensor.detach().cpu().contiguous()) |
|
|
for k, v in obj.items()} |
|
|
|
|
|
if isinstance(obj, dict): |
|
|
|
|
|
candidate_keys = [ |
|
|
"state_dict", |
|
|
"model_state_dict", |
|
|
"model", |
|
|
"module", |
|
|
"network", |
|
|
"net", |
|
|
"weights", |
|
|
] |
|
|
|
|
|
for ck in candidate_keys: |
|
|
if ck in obj and _is_tensor_dict(obj[ck]): |
|
|
d = obj[ck] |
|
|
return {k: (v.detach().cpu().contiguous() if torch.is_tensor(v) else v.tensor.detach().cpu().contiguous()) |
|
|
for k, v in d.items()} |
|
|
|
|
|
|
|
|
for k, v in obj.items(): |
|
|
if _is_tensor_dict(v): |
|
|
d = v |
|
|
return {kk: (vv.detach().cpu().contiguous() if torch.is_tensor(vv) else vv.tensor.detach().cpu().contiguous()) |
|
|
for kk, vv in d.items()} |
|
|
|
|
|
raise ValueError( |
|
|
"Could not find a model state_dict (a dict of tensors). " |
|
|
"If this is a full training checkpoint, load it in Python, extract model.state_dict(), " |
|
|
"and save that mapping instead." |
|
|
) |
|
|
|
|
|
|
|
|
def convert_bin_to_safetensors(in_path: str, out_path: str, metadata: Dict[str, str] = None) -> None: |
|
|
|
|
|
obj = torch.load(in_path, map_location="cpu") |
|
|
|
|
|
|
|
|
if isinstance(obj, (bytes, bytearray)) and is_safe_tensor(obj): |
|
|
print(f"Input appears to already be a safetensors file: {in_path}") |
|
|
return |
|
|
|
|
|
state = _extract_state_dict(obj) |
|
|
|
|
|
|
|
|
meta = {"format": "converted-from-pytorch-bin"} |
|
|
if metadata: |
|
|
meta.update({str(k): str(v) for k, v in metadata.items()}) |
|
|
|
|
|
|
|
|
save_file(state, out_path, metadata=meta) |
|
|
print(f"✅ Wrote {out_path} with {len(state)} tensors.") |
|
|
|
|
|
|
|
|
def main(argv=None):
    """Command-line entry point: parse arguments, collect metadata, convert.

    Args:
        argv: Optional argument list (defaults to sys.argv[1:] via argparse).
    """
    parser = argparse.ArgumentParser(description="Convert PyTorch .bin/.pt/.ckpt to .safetensors")
    parser.add_argument("--in", dest="in_path", required=True, help="Input .bin/.pt/.ckpt file path")
    parser.add_argument("--out", dest="out_path", required=True, help="Output .safetensors file path")
    parser.add_argument("--meta", nargs="*", default=[], help='Optional metadata entries like key=value (repeatable)')
    args = parser.parse_args(argv)

    # Collect key=value pairs; anything without an "=" is reported and skipped.
    metadata = {}
    for entry in args.meta:
        key, sep, value = entry.partition("=")
        if sep:
            metadata[key] = value
        else:
            print(f"Warning: ignoring malformed --meta entry (expected key=value): {entry}", file=sys.stderr)

    convert_bin_to_safetensors(args.in_path, args.out_path, metadata)
|
|
|
|
|
|
|
|
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|
|