"""Convert a PyTorch ``.pth`` checkpoint into a ``.safetensors`` file.

Usage::

    python <this_script> INPUT.pth [--output OUTPUT.safetensors]
"""

import argparse
import os

import torch

# Wrapper keys under which the weights dictionary is looked up, in order.
_STATE_DICT_KEYS = ("state_dict", "model", "model_state_dict", "module")

# Prefix applied to tensor names unless every name already carries it.
_KEY_PREFIX = "model."


def extract_state_dict(checkpoint):
    """Return the tensor dictionary nested inside a loaded checkpoint.

    If *checkpoint* is a dict, the first of ``state_dict``, ``model``,
    ``model_state_dict`` or ``module`` whose value is itself a dict is
    returned.  Otherwise *checkpoint* is returned unchanged (whatever its
    type), leaving validation to the caller.
    """
    if isinstance(checkpoint, dict):
        for key in _STATE_DICT_KEYS:
            nested = checkpoint.get(key)
            if isinstance(nested, dict):
                return nested
    return checkpoint


def prepare_tensor_dict(state_dict) -> dict:
    """Build the name -> tensor mapping that will be written to disk.

    * Entries whose key is not a ``str`` or whose value is not a
      ``torch.Tensor`` are dropped.
    * Strided tensors are made contiguous, and a tensor whose data pointer
      was already seen (the same tensor stored under two names, as with tied
      weights) is cloned, because ``safetensors`` refuses non-contiguous
      tensors and tensors that share memory.  Overlapping views that start
      at different offsets are not detected here.
    * Unless *every* remaining key already starts with ``"model."``, the
      prefix is added to *all* keys (so in a mixed dictionary an existing
      ``"model.x"`` becomes ``"model.model.x"``).

    Returns an empty dict when *state_dict* holds no usable tensors.
    """
    tensors = {}
    seen_ptrs = set()
    for name, value in state_dict.items():
        if not (isinstance(name, str) and isinstance(value, torch.Tensor)):
            continue
        # Only dense (strided) tensors are normalised; other layouts are
        # passed through untouched so the save step can report on them.
        if value.layout == torch.strided:
            value = value.contiguous()
            if value.numel() > 0:
                ptr = value.data_ptr()
                if ptr in seen_ptrs:
                    value = value.clone()
                else:
                    seen_ptrs.add(ptr)
        tensors[name] = value

    if not all(name.startswith(_KEY_PREFIX) for name in tensors):
        tensors = {f"{_KEY_PREFIX}{name}": t for name, t in tensors.items()}
    return tensors


def convert_pth_to_safetensors(input_path, output_path=None) -> bool:
    """Convert the checkpoint at *input_path* to the safetensors format.

    Args:
        input_path: Path of the ``.pth`` file to read.
        output_path: Destination file.  When ``None``, the input path with
            its extension replaced by ``.safetensors`` is used.

    Returns:
        ``True`` when the file was written, ``False`` on any failure.  Every
        failure is reported on stdout; no exception is raised for load or
        save errors.
    """
    # Imported here so a missing package yields a readable message and the
    # helper functions above stay importable without safetensors.
    try:
        from safetensors.torch import save_file
    except ImportError as e:
        print(f"❌ safetensors is not available: {e}")
        return False

    print(f"🔍 Loading checkpoint from: {input_path}")
    try:
        # weights_only=True restricts unpickling to tensors and plain
        # containers, which avoids executing arbitrary pickled code.
        checkpoint = torch.load(input_path, map_location="cpu", weights_only=True)
    except Exception as e:
        print(f"❌ Error loading .pth file: {e}")
        return False

    state_dict = extract_state_dict(checkpoint)
    if not isinstance(state_dict, dict):
        print("❌ Invalid checkpoint: not a dictionary.")
        return False

    tensor_dict = prepare_tensor_dict(state_dict)
    if not tensor_dict:
        print("❌ No tensor values found to convert.")
        return False

    skipped = len(state_dict) - len(tensor_dict)
    if skipped:
        print(f"⚠️ Skipped {skipped} non-tensor entries.")

    if output_path is None:
        output_path = os.path.splitext(input_path)[0] + ".safetensors"

    try:
        print(f"💾 Saving to: {output_path}")
        save_file(tensor_dict, output_path)
    except Exception as e:
        print(f"❌ Saving failed: {e}")
        return False

    print("✅ Conversion to .safetensors successful!")
    return True


def main(argv=None) -> int:
    """Parse command-line arguments, run the conversion, return an exit code."""
    parser = argparse.ArgumentParser(description="Convert .pth to .safetensors")
    parser.add_argument("input", help="Path to input .pth file")
    parser.add_argument("--output", help="Path to output .safetensors file (optional)")
    args = parser.parse_args(argv)
    return 0 if convert_pth_to_safetensors(args.input, args.output) else 1


if __name__ == "__main__":
    raise SystemExit(main())