Spaces:
Running
on
Zero
Running
on
Zero
| import torch, os | |
| from safetensors import safe_open | |
def load_state_dict_from_folder(file_path, torch_dtype=None, suffixs=["safetensors", "bin", "ckpt", "pth", "pt"]):
    """Merge the state dicts of all checkpoint files located directly in a folder.

    Args:
        file_path: Directory to scan (not recursive).
        torch_dtype: Forwarded to ``load_state_dict``; when not ``None`` the
            loaded tensors are cast to this dtype.
        suffixs: File extensions (without the leading dot) that count as
            checkpoint files. The default list is only read, never mutated.

    Returns:
        A single dict containing the entries of every matching file. Files are
        processed in sorted name order, so when two files define the same key
        the one that sorts last wins.
    """
    state_dict = {}
    # os.listdir returns names in arbitrary order; sorting makes the merge
    # result reproducible when several files contain the same key.
    for file_name in sorted(os.listdir(file_path)):
        _, dot, extension = file_name.rpartition(".")
        if not dot or extension not in suffixs:
            continue
        full_path = os.path.join(file_path, file_name)
        if not os.path.isfile(full_path):
            # A sub-directory whose name merely looks like a checkpoint
            # (e.g. "cache.bin/") cannot be loaded; skip it instead of crashing.
            continue
        state_dict.update(load_state_dict(full_path, torch_dtype=torch_dtype))
    return state_dict
def load_state_dict(file_path, torch_dtype=None):
    """Load a state dict from a single file, choosing the reader by extension.

    Paths ending in ``.safetensors`` are read by
    ``load_state_dict_from_safetensors``; every other path is handed to
    ``load_state_dict_from_bin``. ``torch_dtype`` is forwarded unchanged to
    whichever reader is selected.
    """
    is_safetensors = file_path.endswith(".safetensors")
    if not is_safetensors:
        return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
    return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
    """Read every tensor of a safetensors file into a plain dict.

    The file is opened with ``safe_open`` using the ``"pt"`` framework on the
    ``"cpu"`` device. Each key reported by the file is fetched with
    ``get_tensor``; when ``torch_dtype`` is not ``None`` the tensor is
    converted with ``.to(torch_dtype)`` before being stored.

    Returns:
        Dict mapping each key in the file to its (optionally converted) tensor.
    """
    tensors = {}
    with safe_open(file_path, framework="pt", device="cpu") as handle:
        for name in handle.keys():
            tensor = handle.get_tensor(name)
            tensors[name] = tensor if torch_dtype is None else tensor.to(torch_dtype)
    return tensors
def load_state_dict_from_bin(file_path, torch_dtype=None):
    """Load a checkpoint with ``torch.load`` and optionally cast its tensors.

    The file is loaded with ``map_location="cpu"`` and ``weights_only=True``.
    When ``torch_dtype`` is not ``None``, every top-level value that is a
    ``torch.Tensor`` is replaced by ``value.to(torch_dtype)``; values of any
    other type are left as they are.

    Returns:
        The loaded object, with tensor values converted in place when a dtype
        was requested.
    """
    loaded = torch.load(file_path, map_location="cpu", weights_only=True)
    if torch_dtype is None:
        return loaded
    for name in loaded:
        value = loaded[name]
        if isinstance(value, torch.Tensor):
            loaded[name] = value.to(torch_dtype)
    return loaded