bssrdf commited on
Commit
2c20f54
·
verified ·
1 Parent(s): 50de937

Upload convert_to_safe.py

Browse files

added python script to convert original .bin to .safetensors with name changes.

Files changed (1) hide show
  1. convert_to_safe.py +51 -0
convert_to_safe.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Got a bunch of .ckpt files to convert?
2
+ # Here's a handy script to take care of all that for you!
3
+ # Original .ckpt files are not touched!
4
+ # Make sure you have enough disk space! You are going to DOUBLE the size of your models folder!
5
+ #
6
+ # First, run:
7
+ # pip install torch torchsde==0.2.5 safetensors==0.2.5
8
+ #
9
+ # Place this file in the **SAME DIRECTORY** as all of your .ckpt files, open a command prompt for that folder, and run:
10
+ # python convert_to_safe.py
11
+
12
+ import os
13
+ import torch
14
+ from safetensors.torch import save_file
15
+
16
+ files = os.listdir()
17
+ for f in files:
18
+ if f.lower().endswith('.bin'):
19
+ print(f'Loading {f}...')
20
+ fn = f"{f.replace('.bin', '')}.safetensors"
21
+
22
+ if fn in files:
23
+ print(f'Skipping, as {fn} already exists.')
24
+ continue
25
+
26
+ try:
27
+ with torch.no_grad():
28
+ state_dict = torch.load(f, map_location="cpu")
29
+ id_encoder = state_dict["id_encoder"]
30
+ to_be_changed = []
31
+ for key in id_encoder:
32
+ #print(key)
33
+ if "layrnorm" in key:
34
+ print(key)
35
+ newkey = key.replace("layrnorm", "layernorm")
36
+ to_be_changed.append((newkey, key))
37
+ if "visual_projection.w" in key:
38
+ print(key)
39
+ newkey = "vision_model."+key
40
+ to_be_changed.append((newkey, key))
41
+ for nkey,okey in to_be_changed:
42
+ id_encoder[nkey] = id_encoder.pop(okey)
43
+ lora = state_dict["lora_weights"]
44
+ weights = id_encoder | lora
45
+ #fn = f"{f.replace('.bin', '')}.safetensors"
46
+ print(f'Saving {fn}...')
47
+ save_file(weights, fn)
48
+ except Exception as ex:
49
+ print(f'ERROR converting {f}: {ex}')
50
+
51
+ print('Done!')