File size: 2,031 Bytes
8e36426 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 | import os
import argparse
import torch
from safetensors.torch import save_file
def extract_state_dict(checkpoint):
    """Unwrap the weight mapping from a loaded ``.pth`` checkpoint.

    Many training scripts save the weights nested under a wrapper key. When
    *checkpoint* is a dict, the keys ``"state_dict"``, ``"model"``,
    ``"model_state_dict"`` and ``"module"`` are probed in that order. The value
    of the first one that is present and is itself a dict is returned.

    If *checkpoint* is not a dict, or no wrapper key matches, *checkpoint* is
    returned unchanged so the caller can decide whether it is usable.
    """
    if not isinstance(checkpoint, dict):
        return checkpoint
    # Lazily walk the candidate wrapper keys in priority order.
    wrapped = (
        checkpoint[name]
        for name in ("state_dict", "model", "model_state_dict", "module")
        if name in checkpoint and isinstance(checkpoint[name], dict)
    )
    return next(wrapped, checkpoint)
def _make_safetensors_compatible(tensors):
    """Return a copy of *tensors* whose values ``save_file`` can write.

    ``safetensors.torch.save_file`` rejects non-contiguous tensors and tensors
    that share underlying storage (e.g. tied weights, whose sharing survives
    ``torch.save``/``torch.load``). Non-contiguous tensors are made contiguous,
    which copies them. Any later tensor whose storage was already seen is
    cloned, so every saved tensor owns its own memory.
    """
    seen_storages = set()
    prepared = {}
    for name, tensor in tensors.items():
        storage_ptr = tensor.untyped_storage().data_ptr()
        if not tensor.is_contiguous():
            tensor = tensor.contiguous()  # fresh, densely packed copy
        elif storage_ptr in seen_storages:
            tensor = tensor.clone()  # break storage sharing with an earlier entry
        seen_storages.add(storage_ptr)
        prepared[name] = tensor
    return prepared


def convert_pth_to_safetensors(input_path, output_path=None, *, prefix="model."):
    """Convert a PyTorch ``.pth`` checkpoint into a ``.safetensors`` file.

    The checkpoint is loaded on CPU with ``weights_only=True`` and unwrapped
    with ``extract_state_dict``. Every ``torch.Tensor`` value is then written
    out; non-tensor values are dropped, because safetensors stores only
    tensors. Progress and errors are reported with ``print``; nothing is
    raised to the caller.

    Args:
        input_path: Path of the ``.pth`` file to read.
        output_path: Destination path. Defaults to *input_path* with its
            extension replaced by ``.safetensors``.
        prefix: String prepended to every key unless all keys already start
            with it. The default ``"model."`` matches the previous behaviour.
            Pass ``""`` or ``None`` to keep the keys unchanged.

    Returns:
        The path that was written on success, or ``None`` if loading,
        validation or saving failed.
    """
    print(f"🔄 Loading checkpoint from: {input_path}")
    try:
        # weights_only=True refuses arbitrary pickled objects, so loading a
        # checkpoint cannot execute embedded code. Checkpoints containing
        # non-allowlisted Python objects will therefore fail to load here.
        checkpoint = torch.load(input_path, map_location="cpu", weights_only=True)
    except Exception as e:  # CLI boundary: report any load failure and stop
        print(f"❌ Error loading .pth file: {e}")
        return None

    state_dict = extract_state_dict(checkpoint)
    if not isinstance(state_dict, dict):
        print("❌ Invalid checkpoint: not a dictionary.")
        return None

    tensor_dict = {
        k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)
    }
    if not tensor_dict:
        print("❌ No tensor values found to convert.")
        return None

    # Nest every key under *prefix* unless all keys already carry it. The
    # all-or-nothing check is deliberate: prefixing only the keys that lack it
    # would mangle models with a submodule literally named like the prefix.
    # NOTE(review): "model." was intended to give HuggingFace-style keys;
    # confirm against the loader that consumes the output.
    if prefix and not all(k.startswith(prefix) for k in tensor_dict):
        tensor_dict = {f"{prefix}{k}": v for k, v in tensor_dict.items()}

    if output_path is None:
        output_path = os.path.splitext(input_path)[0] + ".safetensors"

    try:
        print(f"💾 Saving to: {output_path}")
        save_file(_make_safetensors_compatible(tensor_dict), output_path)
    except Exception as e:  # CLI boundary: report any save failure
        print(f"❌ Saving failed: {e}")
        return None
    print("✅ Conversion to .safetensors successful!")
    return output_path
if __name__ == "__main__":
    # Command-line entry point: one positional input path and an optional
    # --output destination, both forwarded to the converter.
    cli = argparse.ArgumentParser(description="Convert .pth to .safetensors")
    cli.add_argument("input", help="Path to input .pth file")
    cli.add_argument(
        "--output",
        help="Path to output .safetensors file (optional)",
    )
    options = cli.parse_args()
    convert_pth_to_safetensors(options.input, output_path=options.output)
|