Lexa
Converted .pt files to safetensors, then (dirtily) patched fairseq to enable loading of safetensor files
b5a0bec
import os

import torch
from safetensors.torch import save_file

# Define the location and files to process.
location = "_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000"
files = ["model", "rank_0", "metadata"]


def _load_checkpoint(pt_path):
    """Load a .pt checkpoint, preferring the safe ``weights_only=True`` mode.

    Falls back to ``weights_only=False`` (arbitrary-code unpickling — only
    safe for trusted files) when the safe mode fails, mirroring the original
    two-step strategy. Returns the loaded object, or ``None`` if both
    attempts fail.
    """
    try:
        # Attempt to load the checkpoint with weights_only=True
        return torch.load(pt_path, weights_only=True)
    except Exception as e:
        print(f"Warning: Failed to load {pt_path} with weights_only=True due to {e}")
        print("Attempting to load with weights_only=False (ensure the source is trusted).")
    try:
        return torch.load(pt_path, weights_only=False)
    except Exception as e:
        print(f"Error: Failed to load {pt_path} with weights_only=False due to {e}")
        return None


def _convert_file(pt_path, st_path):
    """Convert one .pt checkpoint into a .safetensors file.

    Returns True on success, False when the file could not be loaded or
    does not contain a state dict.
    """
    checkpoint = _load_checkpoint(pt_path)
    if checkpoint is None:
        return False  # Load failure already reported; skip this file.
    # Fairseq-style checkpoints nest the weights under the 'model' key;
    # otherwise treat the whole object as the state dict. Guard against a
    # non-dict payload (e.g. a bare tensor): the unguarded
    # checkpoint.get(...) previously raised AttributeError and aborted the
    # whole run instead of skipping the one bad file.
    if isinstance(checkpoint, dict):
        state_dict = checkpoint.get('model', checkpoint)
    else:
        state_dict = checkpoint
    if not isinstance(state_dict, dict):
        print(f"Error: {pt_path} does not contain a state dict; skipping.")
        return False
    # safetensors can only store tensors, so filter out non-tensor entries —
    # but report how many were dropped rather than discarding them silently.
    tensor_state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
    dropped = len(state_dict) - len(tensor_state_dict)
    if dropped:
        print(f"Note: dropped {dropped} non-tensor entries from {pt_path}")
    # Save the filtered state_dict to a .safetensors file
    save_file(tensor_state_dict, st_path)
    print(f"Successfully converted {pt_path} to {st_path}")
    return True


# Kept at module top level (not behind a __main__ guard) so the script's
# run-on-import behavior matches the original.
for file in files:
    _convert_file(
        os.path.join(location, f"{file}.pt"),
        os.path.join(location, f"{file}.safetensors"),
    )