camenduru commited on
Commit
3c26d57
·
1 Parent(s): a533226

thanks to facechain ❤

Browse files
test/convert-to-safetensors.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from https://github.com/huggingface/diffusers/issues/2326 by https://github.com/ignacfetser
3
+
4
+ The LoRA trained using Diffusers are saved in .bin or .pkl format, which must be converted to be used in Automatic1111 WebUI.
5
+
6
+ This script converts .bin or .pkl files into .safetensors format, which can be used in WebUI.
7
+
8
+ Put this file in the same folder of .bin or .pkl file and run `python convert-to-safetensors.py --file checkpoint_file`
9
+
10
+ """
11
+ import re
12
+ import os
13
+ import argparse
14
+ import torch;
15
+ from safetensors.torch import save_file
16
+
17
+ def main(args):
18
+
19
+ ## use GPU or CPU
20
+ if torch.cuda.is_available():
21
+ device = 'cuda'
22
+ checkpoint = torch.load(args.file, map_location=torch.device('cuda'))
23
+ else:
24
+ device = 'cpu'
25
+ # if on CPU or want to have maximum precision on GPU, use default full-precision setting
26
+ checkpoint = torch.load(args.file, map_location=torch.device('cpu'))
27
+
28
+ print(f'device is {device}')
29
+
30
+
31
+ new_dict = dict()
32
+ for idx, key in enumerate(checkpoint):
33
+ new_key = re.sub('\.processor\.', '_', key)
34
+ new_key = re.sub('mid_block\.', 'mid_block_', new_key)
35
+ new_key = re.sub('_lora.up.', '.lora_up.', new_key)
36
+ new_key = re.sub('_lora.down.', '.lora_down.', new_key)
37
+ new_key = re.sub('\.(\d+)\.', '_\\1_', new_key)
38
+ new_key = re.sub('to_out', 'to_out_0', new_key)
39
+ new_key = 'lora_unet_' + new_key
40
+
41
+ new_dict[new_key] = checkpoint[key]
42
+
43
+ file_name = os.path.splitext(args.file)[0] # get the file name without the extension
44
+ new_lora_name = file_name + '_converted.safetensors'
45
+ print("Saving " + new_lora_name)
46
+ save_file(new_dict, new_lora_name)
47
+
48
+
49
+ def parse_args():
50
+ parser = argparse.ArgumentParser()
51
+ parser.add_argument(
52
+ "--file",
53
+ type=str,
54
+ default=None,
55
+ required=True,
56
+ help="path to the full file name",
57
+ )
58
+
59
+ args = parser.parse_args()
60
+ return args
61
+
62
+
63
+ if __name__ == "__main__":
64
+ args = parse_args()
65
+ main(args)
test/pytorch_lora_weights_converted.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:10de9be01dd9693ffea500c5c6924b0c039bad9d00167e17abfe9ab2d23a3546
3
+ size 102079400