# bin_to_safetensors.py (repo: nnn-3)
# Provenance: duplicated from xeophon/NVIDIA-Nemotron-Nano-3-30B-A3B-BF16
# (commit 9af1f06, verified), uploaded by Alignment-Lab-AI.
#!/usr/bin/env python3
"""
bin_to_safetensors.py
Convert a PyTorch checkpoint (e.g., pytorch_model.bin / .pt / .ckpt) to a .safetensors file.
- Safe tensors only: tensors are saved; non-tensor Python objects (optimizer, schedulers, etc.) are ignored.
- Heuristics try to locate a model state_dict within common training checkpoints.
USAGE:
python bin_to_safetensors.py --in pytorch_model.bin --out model.safetensors
python bin_to_safetensors.py --in trainer.ckpt --out model.safetensors
NOTE:
Loading with torch.load uses pickle and can execute code from untrusted sources.
Only run this on checkpoints from sources you trust.
"""
import argparse
import sys
from typing import Any, Dict

import torch
# NOTE: the previous revision also imported `is_safe_tensor` from
# safetensors.torch; no such function exists in the safetensors API, so the
# script died with ImportError at startup. Only `save_file` is needed.
from safetensors.torch import save_file
def _is_tensor_dict(d: Any) -> bool:
    """Return True iff *d* is a non-empty dict whose every value is a
    torch.Tensor or a ShardedTensor-like object exposing a ``.tensor``
    attribute that is itself a torch.Tensor."""
    if not isinstance(d, dict) or not d:
        return False

    def _tensor_like(value: Any) -> bool:
        # Plain tensor, or a wrapper carrying its payload in `.tensor`.
        return torch.is_tensor(value) or (
            hasattr(value, "tensor") and torch.is_tensor(value.tensor)
        )

    return all(_tensor_like(v) for v in d.values())
def _extract_state_dict(obj: Any) -> Dict[str, torch.Tensor]:
    """
    Try to extract a {name: tensor} dict from various checkpoint formats.

    Accepts either a plain tensor dict (typical HF "pytorch_model.bin") or a
    training checkpoint that nests one under a well-known key or any other
    top-level value (Lightning-style checkpoints).

    Raises:
        ValueError: if no dict-of-tensors can be located inside *obj*.
    """

    def _materialize(d: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        # Normalize each value to a detached, CPU-resident, contiguous tensor.
        # ShardedTensor-like values carry their payload in `.tensor`.
        # (Previously this expression was copy-pasted in three places.)
        out: Dict[str, torch.Tensor] = {}
        for name, value in d.items():
            tensor = value if torch.is_tensor(value) else value.tensor
            out[name] = tensor.detach().cpu().contiguous()
        return out

    # Case 1: already a tensor dict (typical HF "pytorch_model.bin").
    if _is_tensor_dict(obj):
        return _materialize(obj)

    if isinstance(obj, dict):
        # Case 2: common wrapper keys used by training frameworks.
        candidate_keys = (
            "state_dict",
            "model_state_dict",
            "model",
            "module",  # sometimes raw module.state_dict() is stored here
            "network",
            "net",
            "weights",
        )
        for ck in candidate_keys:
            if ck in obj and _is_tensor_dict(obj[ck]):
                return _materialize(obj[ck])
        # Case 3: any nested top-level value that is itself a tensor dict
        # (only the values matter here, so don't iterate .items()).
        for v in obj.values():
            if _is_tensor_dict(v):
                return _materialize(v)

    raise ValueError(
        "Could not find a model state_dict (a dict of tensors). "
        "If this is a full training checkpoint, load it in Python, extract model.state_dict(), "
        "and save that mapping instead."
    )
def _looks_like_safetensors(path: str) -> bool:
    """Best-effort sniff for the safetensors on-disk format.

    A .safetensors file starts with an 8-byte little-endian header length
    followed by a JSON header, whose first byte is '{'. Returns False when
    the file cannot be read.
    """
    try:
        with open(path, "rb") as f:
            head = f.read(9)
    except OSError:
        return False
    return len(head) == 9 and head[8:9] == b"{"


def convert_bin_to_safetensors(in_path: str, out_path: str, metadata: Dict[str, str] = None) -> None:
    """Convert a pickled PyTorch checkpoint at *in_path* to a .safetensors
    file at *out_path*, copying only tensors (non-tensor state is dropped).

    Args:
        in_path: input .bin/.pt/.ckpt checkpoint path.
        out_path: destination .safetensors path.
        metadata: optional extra key/value metadata (stringified) to embed.

    Raises:
        ValueError: if no tensor state_dict can be found in the checkpoint.
    """
    # Bail out early if the input already looks like safetensors. This must
    # happen BEFORE torch.load: the previous revision checked afterwards with
    # a nonexistent `is_safe_tensor` helper, on a `bytes` object torch.load
    # can never return for a file path — torch.load would already have raised.
    if _looks_like_safetensors(in_path):
        print(f"Input appears to already be a safetensors file: {in_path}")
        return
    # TRUST WARNING: torch.load uses pickle. Only load from trusted files.
    obj = torch.load(in_path, map_location="cpu")
    state = _extract_state_dict(obj)
    # Optional basic metadata; safetensors requires str->str.
    meta = {"format": "converted-from-pytorch-bin"}
    if metadata:
        meta.update({str(k): str(v) for k, v in metadata.items()})
    save_file(state, out_path, metadata=meta)
    print(f"✅ Wrote {out_path} with {len(state)} tensors.")
def main(argv=None):
    """CLI entry point: parse arguments, collect metadata, run conversion."""
    ap = argparse.ArgumentParser(description="Convert PyTorch .bin/.pt/.ckpt to .safetensors")
    ap.add_argument("--in", dest="in_path", required=True, help="Input .bin/.pt/.ckpt file path")
    ap.add_argument("--out", dest="out_path", required=True, help="Output .safetensors file path")
    ap.add_argument("--meta", nargs="*", default=[], help='Optional metadata entries like key=value (repeatable)')
    ns = ap.parse_args(argv)

    # Collect well-formed key=value pairs; warn (on stderr) about the rest.
    meta: Dict[str, str] = {}
    for entry in ns.meta:
        key, sep, value = entry.partition("=")
        if sep:
            meta[key] = value
        else:
            print(f"Warning: ignoring malformed --meta entry (expected key=value): {entry}", file=sys.stderr)

    convert_bin_to_safetensors(ns.in_path, ns.out_path, meta)


if __name__ == "__main__":
    main()