# Convert saved .pt checkpoint files to the .safetensors format.
"""Convert PyTorch ``.pt`` checkpoint files to the ``.safetensors`` format.

For each base name in ``files``, loads ``<name>.pt`` from ``location`` and
writes the tensor-only entries of its state dict to ``<name>.safetensors``
in the same directory. Files that cannot be loaded are reported and skipped.
"""
import torch
from safetensors.torch import save_file
import os

# Directory containing the checkpoint files to process.
location = "_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000"
# Base names (without extension) of the checkpoint files to convert.
files = ["model", "rank_0", "metadata"]

for file in files:
    pt_path = os.path.join(location, f"{file}.pt")
    st_path = os.path.join(location, f"{file}.safetensors")
    try:
        # Prefer the safe path first: weights_only=True refuses to unpickle
        # arbitrary objects, so it cannot execute code embedded in the file.
        checkpoint = torch.load(pt_path, weights_only=True)
    except Exception as e:
        print(f"Warning: Failed to load {pt_path} with weights_only=True due to {e}")
        print("Attempting to load with weights_only=False (ensure the source is trusted).")
        try:
            # Fallback: full unpickling. Only safe if the checkpoint source
            # is trusted, hence the warning printed above.
            checkpoint = torch.load(pt_path, weights_only=False)
        except Exception as e:
            print(f"Error: Failed to load {pt_path} with weights_only=False due to {e}")
            continue  # Skip to the next file

    # Checkpoints may nest the weights under a 'model' key; otherwise treat
    # the loaded object itself as the state dict.
    if isinstance(checkpoint, dict):
        state_dict = checkpoint.get('model', checkpoint)
    else:
        state_dict = checkpoint
    # Guard: torch.load can return non-dict objects (e.g. a bare tensor);
    # report and skip rather than crash on .items() below.
    if not isinstance(state_dict, dict):
        print(f"Warning: {pt_path} did not contain a dict state_dict; skipping.")
        continue

    # safetensors can only store tensors, so drop non-tensor entries
    # (step counters, metadata scalars, optimizer bookkeeping, ...).
    tensor_state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
    # Save the filtered state_dict to a .safetensors file.
    save_file(tensor_state_dict, st_path)
    print(f"Successfully converted {pt_path} to {st_path}")
|