import torch, os from safetensors import safe_open def load_state_dict_from_folder(file_path, torch_dtype=None, suffixs = ["safetensors", "bin", "ckpt", "pth", "pt"]): state_dict = {} for file_name in os.listdir(file_path): if "." in file_name and file_name.split(".")[-1] in suffixs: state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype)) return state_dict def load_state_dict(file_path, torch_dtype=None): if file_path.endswith(".safetensors"): return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype) else: return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype) def load_state_dict_from_safetensors(file_path, torch_dtype=None): state_dict = {} with safe_open(file_path, framework="pt", device="cpu") as f: for k in f.keys(): state_dict[k] = f.get_tensor(k) if torch_dtype is not None: state_dict[k] = state_dict[k].to(torch_dtype) return state_dict def load_state_dict_from_bin(file_path, torch_dtype=None): state_dict = torch.load(file_path, map_location="cpu", weights_only=True) if torch_dtype is not None: for i in state_dict: if isinstance(state_dict[i], torch.Tensor): state_dict[i] = state_dict[i].to(torch_dtype) return state_dict