import os
import argparse
import torch
from safetensors.torch import save_file

def extract_state_dict(checkpoint):
    """

    Extracts the tensor dictionary from common .pth formats.

    """
    if isinstance(checkpoint, dict):
        for key in ["state_dict", "model", "model_state_dict", "module"]:
            if key in checkpoint and isinstance(checkpoint[key], dict):
                return checkpoint[key]
    return checkpoint

def convert_pth_to_safetensors(input_path, output_path=None):
    print(f"🔍 Loading checkpoint from: {input_path}")

    try:
        checkpoint = torch.load(input_path, map_location="cpu", weights_only=True)
    except Exception as e:
        print(f"❌ Error loading .pth file: {e}")
        return

    state_dict = extract_state_dict(checkpoint)

    if not isinstance(state_dict, dict):
        print("❌ Invalid checkpoint: not a dictionary.")
        return

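    # Keep only tensor values; save_file rejects non-contiguous tensors, so normalize them here.
    tensor_dict = {k: v.contiguous() for k, v in state_dict.items() if isinstance(v, torch.Tensor)}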

    if not tensor_dict:
        print("❌ No tensor values found to convert.")
        return

    # Prefix every key with "model." unless all keys already carry it (HuggingFace-style naming)
    if not all(k.startswith("model.") for k in tensor_dict):
        tensor_dict = {f"model.{k}": v for k, v in tensor_dict.items()}

    if output_path is None:
        output_path = os.path.splitext(input_path)[0] + ".safetensors"

    try:
        print(f"💾 Saving to: {output_path}")
        save_file(tensor_dict, output_path)
        print("✅ Conversion to .safetensors successful!")
    except Exception as e:
        print(f"❌ Saving failed: {e}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert .pth to .safetensors")
    parser.add_argument("input", help="Path to input .pth file")
    parser.add_argument("--output", help="Path to output .safetensors file (optional)")

    args = parser.parse_args()
    convert_pth_to_safetensors(args.input, args.output)