Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,388 Bytes
400a879 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import torch, os
from safetensors import safe_open
def load_state_dict_from_folder(file_path, torch_dtype=None, suffixs=("safetensors", "bin", "ckpt", "pth", "pt")):
    """Load every checkpoint file directly inside a folder and merge them.

    Args:
        file_path: Path of the directory to scan. Only its direct entries
            are considered; sub-directories are not descended into.
        torch_dtype: Optional dtype forwarded to ``load_state_dict``.
        suffixs: File extensions (without the leading dot) that mark a file
            as a checkpoint. The default is an immutable tuple so it cannot
            be altered between calls.

    Returns:
        A single dict holding the entries of all matching files. Files are
        visited in sorted name order, so when two files share a key the
        value from the later name wins.
    """
    state_dict = {}
    # os.listdir returns entries in arbitrary order; sorting makes the
    # outcome of key collisions reproducible across machines/filesystems.
    for file_name in sorted(os.listdir(file_path)):
        if "." not in file_name or file_name.rsplit(".", 1)[-1] not in suffixs:
            continue
        full_path = os.path.join(file_path, file_name)
        # A directory that merely ends in a checkpoint suffix cannot be
        # loaded as a file, so skip anything that is not a regular file.
        if not os.path.isfile(full_path):
            continue
        state_dict.update(load_state_dict(full_path, torch_dtype=torch_dtype))
    return state_dict
def load_state_dict(file_path, torch_dtype=None):
    """Load a checkpoint file, choosing the loader from its extension.

    Paths ending in ``.safetensors`` go to
    ``load_state_dict_from_safetensors``; every other path goes to
    ``load_state_dict_from_bin``.

    Args:
        file_path: Path of the checkpoint file (a string).
        torch_dtype: Optional dtype forwarded to the selected loader.

    Returns:
        Whatever the selected loader returns.
    """
    is_safetensors = file_path.endswith(".safetensors")
    loader = load_state_dict_from_safetensors if is_safetensors else load_state_dict_from_bin
    return loader(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
    """Read all tensors of a safetensors file into a dict on the CPU.

    Args:
        file_path: Path of the ``.safetensors`` file.
        torch_dtype: If not ``None``, each tensor is converted with
            ``.to(torch_dtype)`` right after it is read.

    Returns:
        A dict mapping each key in the file to its tensor.
    """
    tensors = {}
    with safe_open(file_path, framework="pt", device="cpu") as handle:
        for key in handle.keys():
            tensor = handle.get_tensor(key)
            # Cast one tensor at a time, while the file is still open, so
            # the uncast copy can be dropped before the next one is read.
            tensors[key] = tensor if torch_dtype is None else tensor.to(torch_dtype)
    return tensors
def load_state_dict_from_bin(file_path, torch_dtype=None):
    """Load a checkpoint with ``torch.load`` onto the CPU.

    The file is read with ``map_location="cpu"`` and ``weights_only=True``.

    Args:
        file_path: Path of the checkpoint file.
        torch_dtype: If not ``None``, every top-level value that is a
            ``torch.Tensor`` is replaced by ``value.to(torch_dtype)``.
            Non-tensor values are left untouched.

    Returns:
        The object returned by ``torch.load``, modified in place when a
        dtype is given.
    """
    loaded = torch.load(file_path, map_location="cpu", weights_only=True)
    if torch_dtype is None:
        return loaded
    # Only existing keys are reassigned, so the container's size never
    # changes while it is being iterated.
    for name in loaded:
        entry = loaded[name]
        if isinstance(entry, torch.Tensor):
            loaded[name] = entry.to(torch_dtype)
    return loaded
|