"""Convert PyTorch checkpoint files (``.pt``) to the safetensors format."""

import os

import torch
from safetensors.torch import save_file

# Directory containing the checkpoint files, and the file stems to convert.
LOCATION = "_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000"
FILES = ("model", "rank_0", "metadata")


def convert_checkpoints(location: str = LOCATION, files=FILES) -> None:
    """Convert each ``<name>.pt`` under *location* to ``<name>.safetensors``.

    For every stem in *files*, loads ``<stem>.pt`` (preferring the safe
    ``weights_only=True`` path, falling back to ``weights_only=False`` for
    trusted sources), extracts the state_dict, drops non-tensor entries,
    and writes the result next to the original as ``<stem>.safetensors``.

    Failures are reported and the offending file is skipped; nothing is
    raised to the caller.
    """
    for name in files:
        pt_path = os.path.join(location, f"{name}.pt")
        st_path = os.path.join(location, f"{name}.safetensors")

        try:
            # Attempt the safe load first: weights_only=True refuses to
            # unpickle arbitrary objects.
            checkpoint = torch.load(pt_path, weights_only=True)
        except Exception as e:
            print(f"Warning: Failed to load {pt_path} with weights_only=True due to {e}")
            print("Attempting to load with weights_only=False (ensure the source is trusted).")
            try:
                # SECURITY: weights_only=False uses full pickle and can
                # execute arbitrary code — only for trusted checkpoints.
                checkpoint = torch.load(pt_path, weights_only=False)
            except Exception as e:
                print(f"Error: Failed to load {pt_path} with weights_only=False due to {e}")
                continue  # Skip to the next file

        # Some checkpoints wrap the state_dict under a 'model' key; others are
        # the state_dict itself. A non-dict payload (raw tensor, pickled
        # module, ...) has no .get() — skip it instead of crashing.
        if not isinstance(checkpoint, dict):
            print(f"Error: {pt_path} does not contain a dict-like checkpoint; skipping.")
            continue
        state_dict = checkpoint.get("model", checkpoint)
        if not isinstance(state_dict, dict):
            print(f"Error: 'model' entry in {pt_path} is not a state_dict; skipping.")
            continue

        # Keep only tensors; safetensors cannot serialize anything else, and
        # it requires contiguous memory, so force .contiguous() (a no-op for
        # tensors that are already contiguous).
        tensor_state_dict = {
            k: v.contiguous()
            for k, v in state_dict.items()
            if isinstance(v, torch.Tensor)
        }

        save_file(tensor_state_dict, st_path)
        print(f"Successfully converted {pt_path} to {st_path}")


if __name__ == "__main__":
    # Guarded entry point: importing this module no longer touches the disk.
    convert_checkpoints()