| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| import os |
| import torch |
| from safetensors.torch import save_file |
|
|
# Convert every .ckpt file found (recursively) under the current working
# directory into a .safetensors file written next to it.

CKPT_EXT = ".ckpt"
SAFETENSORS_EXT = ".safetensors"


def find_checkpoints(root):
    """Return the paths of all files under *root* whose name ends in ``.ckpt``.

    The extension is matched case-insensitively. Paths are returned in
    ``os.walk`` order and are joined onto *root*, so an absolute *root*
    yields absolute paths.
    """
    checkpoints = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            if name.lower().endswith(CKPT_EXT):
                checkpoints.append(os.path.join(dirpath, name))
    return checkpoints


def safetensors_path_for(ckpt_path):
    """Return *ckpt_path* with its final extension replaced by ``.safetensors``.

    Only the trailing extension is replaced, so directory names containing
    ".ckpt" and upper-case extensions such as ``.CKPT`` are handled correctly.
    """
    return os.path.splitext(ckpt_path)[0] + SAFETENSORS_EXT


def extract_tensors(checkpoint):
    """Reduce a loaded checkpoint to the flat ``{name: tensor}`` dict ``save_file`` accepts.

    If *checkpoint* has a ``"state_dict"`` entry that is a dict (the layout of
    training checkpoints), that entry is used; otherwise *checkpoint* itself is
    treated as the state dict. Non-tensor values are dropped, non-contiguous
    tensors are made contiguous, and a tensor whose storage was already seen
    under an earlier key is cloned, because ``save_file`` rejects both
    non-contiguous and storage-sharing tensors.

    Raises:
        TypeError: if *checkpoint* does not resolve to a dict.
        ValueError: if the resolved state dict contains no tensors.
    """
    state_dict = checkpoint
    if isinstance(state_dict, dict) and isinstance(state_dict.get("state_dict"), dict):
        state_dict = state_dict["state_dict"]
    if not isinstance(state_dict, dict):
        raise TypeError(f"expected a dict of tensors, got {type(state_dict).__name__}")

    tensors = {}
    seen_storages = set()
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            continue  # e.g. counters or config stored alongside the weights
        tensor = value.contiguous()  # returns the same tensor when already contiguous
        # untyped_storage() is the torch >= 2.0 API; storage() is the older one.
        storage = tensor.untyped_storage() if hasattr(tensor, "untyped_storage") else tensor.storage()
        if storage.data_ptr() in seen_storages:
            tensor = tensor.clone()  # give tied/aliased weights their own storage
        else:
            seen_storages.add(storage.data_ptr())
        tensors[key] = tensor
    if not tensors:
        raise ValueError("checkpoint contains no tensors")
    return tensors


def convert_checkpoint(ckpt_path, output_path):
    """Load *ckpt_path* on the CPU and save its tensors to *output_path*.

    The loaded weights are locals of this function, so they are released when
    it returns and only one checkpoint is held in memory at a time.
    """
    with torch.no_grad():
        # SECURITY: torch.load unpickles the file, so a malicious .ckpt can run
        # arbitrary code. Only convert checkpoints from sources you trust.
        checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
        tensors = extract_tensors(checkpoint)
        print(f"Saving {os.path.basename(output_path)}...")
        try:
            save_file(tensors, output_path)
        except BaseException:
            # A failed write may leave a truncated file that a later run would
            # skip as "already exists"; the caller only gets here when the
            # output did not exist beforehand, so removing it is safe.
            if os.path.exists(output_path):
                os.remove(output_path)
            raise


def main():
    """Interactively convert every .ckpt under the working directory, skipping existing outputs."""
    root = os.path.abspath(os.getcwd())
    models = find_checkpoints(root)

    if not models:
        print(f"\033[91m> No .ckpt files found in this directory ({root}).\033[0m")
        input("> Press enter to exit... ")
        return
    print(f"\n\033[92m> Found {len(models)} .ckpt files to convert.\033[0m")
    for number, model in enumerate(models, start=1):
        print(f"{number}: {os.path.basename(model)}")

    input("> Press enter to continue... ")
    print("\n")

    total = len(models)
    for number, ckpt_path in enumerate(models, start=1):
        model_name = os.path.basename(ckpt_path)
        output_path = safetensors_path_for(ckpt_path)
        tensor_name = os.path.basename(output_path)

        # Checked on disk rather than against a pre-scan, so outputs written
        # earlier in this run (or differing only in case on Windows) count too.
        if os.path.exists(output_path):
            print(f"\033[33m\n> Skipping {model_name}, as {tensor_name} already exists.\033[0m")
            continue

        print(f"\n> Loading {model_name} ({number}/{total})...")
        try:
            convert_checkpoint(ckpt_path, output_path)
        except Exception as ex:  # best-effort batch: report and move on
            print(f"ERROR converting {model_name}: {ex}")

    print("\n\033[92mDone!\033[0m")
    input("> Press enter to exit... ")


if __name__ == "__main__":
    main()
|
|