|
|
import os
|
|
|
import argparse
|
|
|
import torch
|
|
|
from safetensors.torch import save_file
|
|
|
|
|
|
def extract_state_dict(checkpoint):
    """
    Unwrap the state dictionary from a loaded .pth checkpoint object.

    Dict checkpoints are probed for the wrapper keys "state_dict", "model",
    "model_state_dict" and "module", in that order; the first one whose value
    is itself a dict is returned. When no wrapper key matches, or when the
    checkpoint is not a dict at all, the checkpoint is returned unchanged.
    """
    # Non-dict objects are passed through for the caller to validate.
    if not isinstance(checkpoint, dict):
        return checkpoint

    wrapper_keys = ("state_dict", "model", "model_state_dict", "module")
    for name in wrapper_keys:
        candidate = checkpoint.get(name)
        if isinstance(candidate, dict):
            return candidate

    # No wrapper found: treat the checkpoint itself as the state dict.
    return checkpoint
|
|
|
|
|
|
def convert_pth_to_safetensors(input_path, output_path=None):
    """
    Convert a PyTorch .pth checkpoint into a .safetensors file.

    The checkpoint is loaded on CPU with ``weights_only=True``, unwrapped via
    ``extract_state_dict``, reduced to its tensor-valued entries and written
    with ``save_file``. Progress and failures are reported on stdout; load and
    save errors are caught and printed rather than raised.

    Args:
        input_path: Path to the source .pth checkpoint.
        output_path: Destination path. When omitted, ``input_path`` with its
            extension replaced by ``.safetensors`` is used.

    Returns:
        None in every case (success or failure).
    """
    # Status prefixes are plain ASCII so the messages print on any console
    # encoding.
    print(f"[INFO] Loading checkpoint from: {input_path}")

    try:
        checkpoint = torch.load(input_path, map_location="cpu", weights_only=True)
    except Exception as e:
        print(f"[FAIL] Error loading .pth file: {e}")
        return

    state_dict = extract_state_dict(checkpoint)
    if not isinstance(state_dict, dict):
        print("[FAIL] Invalid checkpoint: not a dictionary.")
        return

    # Non-tensor entries are dropped. Tensors are made contiguous because
    # save_file rejects non-contiguous tensors; .contiguous() returns the
    # tensor itself when it is already contiguous.
    tensor_dict = {
        k: v.contiguous()
        for k, v in state_dict.items()
        if isinstance(v, torch.Tensor)
    }
    if not tensor_dict:
        print("[FAIL] No tensor values found to convert.")
        return

    # Prefix every key with "model." unless all keys already carry it.
    # NOTE(review): with a mix of prefixed and unprefixed keys, the already
    # prefixed ones become "model.model.*" - confirm this is intended.
    if not all(k.startswith("model.") for k in tensor_dict):
        tensor_dict = {f"model.{k}": v for k, v in tensor_dict.items()}

    if output_path is None:
        output_path = os.path.splitext(input_path)[0] + ".safetensors"

    print(f"[INFO] Saving to: {output_path}")
    try:
        save_file(tensor_dict, output_path)
    except Exception as e:
        print(f"[FAIL] Saving failed: {e}")
    else:
        # Reported outside the try body so that only save_file errors are
        # labelled as a failed save.
        print("[OK] Conversion to .safetensors successful!")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: one required input path and an optional
    # output path, both forwarded to the converter.
    cli = argparse.ArgumentParser(description="Convert .pth to .safetensors")
    cli.add_argument("input", help="Path to input .pth file")
    cli.add_argument("--output", help="Path to output .safetensors file (optional)")

    options = cli.parse_args()
    convert_pth_to_safetensors(input_path=options.input, output_path=options.output)
|
|
|
|