File size: 1,318 Bytes
b5a0bec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
import pickle

import torch
from safetensors.torch import save_file

# Convert the PyTorch checkpoint files of one training step into
# .safetensors files written next to them (<name>.pt -> <name>.safetensors).
location = "_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000"
files = ["model", "rank_0", "metadata"]


def load_checkpoint(pt_path):
    """Load ``pt_path`` onto the CPU, preferring the safe weights-only loader.

    ``map_location="cpu"`` lets checkpoints saved from GPU tensors load on
    hosts without that GPU.

    A full (unsafe) unpickle is attempted only when the weights-only loader
    rejects the file's contents with ``pickle.UnpicklingError``. Any other
    failure (missing file, corrupt archive, ...) propagates to the caller
    instead of triggering an unsafe retry.

    Raises:
        Whatever ``torch.load`` raises on the final attempt.
    """
    try:
        return torch.load(pt_path, map_location="cpu", weights_only=True)
    except pickle.UnpicklingError as e:
        print(f"Warning: Failed to load {pt_path} with weights_only=True due to {e}")
        print("Attempting to load with weights_only=False (ensure the source is trusted).")
    # SECURITY: weights_only=False unpickles arbitrary objects and can run
    # code embedded in the file. Only convert checkpoints you trust.
    return torch.load(pt_path, map_location="cpu", weights_only=False)


def extract_tensors(checkpoint):
    """Return the tensor entries of ``checkpoint`` in a form ``save_file`` accepts.

    If ``checkpoint`` is a dict with a ``"model"`` key, that entry is used as
    the state dict; otherwise ``checkpoint`` itself is. An ``nn.Module`` is
    reduced to its ``state_dict()``. If the result is not a dict, ``{}`` is
    returned.

    Non-tensor values are dropped; nested dicts are not flattened. Every
    kept tensor is made contiguous. A tensor whose storage is already used
    by an earlier kept tensor (e.g. tied weights) is cloned, because
    safetensors refuses to serialize non-contiguous or memory-sharing
    tensors.
    """
    if isinstance(checkpoint, dict):
        state_dict = checkpoint.get("model", checkpoint)
    else:
        state_dict = checkpoint
    if isinstance(state_dict, torch.nn.Module):
        state_dict = state_dict.state_dict()
    if not isinstance(state_dict, dict):
        return {}

    tensors = {}
    seen_storages = set()
    for name, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            continue
        # No-op for contiguous tensors; otherwise copies into fresh storage.
        tensor = value.contiguous()
        storage = (
            tensor.untyped_storage()
            if hasattr(tensor, "untyped_storage")
            else tensor.storage()  # older torch without untyped_storage()
        )
        storage_key = (tensor.device, storage.data_ptr())
        if storage_key in seen_storages:
            # Give the duplicate its own memory so save_file accepts it.
            tensor = tensor.clone()
        else:
            seen_storages.add(storage_key)
        tensors[name] = tensor
    return tensors


for file in files:
    pt_path = os.path.join(location, f"{file}.pt")
    st_path = os.path.join(location, f"{file}.safetensors")

    try:
        checkpoint = load_checkpoint(pt_path)
    except Exception as e:  # top-level boundary: report and move on
        print(f"Error: Failed to load {pt_path} due to {e}")
        continue  # Skip to the next file

    tensor_state_dict = extract_tensors(checkpoint)
    if not tensor_state_dict:
        print(f"Warning: no tensors found in {pt_path}; {st_path} will be empty.")

    try:
        save_file(tensor_state_dict, st_path)
    except Exception as e:  # e.g. unsupported tensor layout or I/O error
        print(f"Error: Failed to save {st_path} due to {e}")
        continue
    print(f"Successfully converted {pt_path} to {st_path}")