Symbiomatrix commited on
Commit
7c6646d
·
verified ·
1 Parent(s): 89ed840

Create safetensors_converter

Browse files
Files changed (1) hide show
  1. safetensors_converter +72 -0
safetensors_converter ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/DiffusionDalmation/pt_to_safetensors_converter_notebook
2
+
3
+ import os
4
+ from typing import Any, Dict
5
+ import torch
6
+ from safetensors.torch import save_file
7
+
8
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
+
10
+ # Supported file extensions
11
+ SUPPORTED_EXTENSIONS = ['.pth', '.pt', '.bin', '.ckpt']
12
+
13
+ def is_supported_file(input_path):
14
+ """Check if file has a supported extension"""
15
+ _, ext = os.path.splitext(input_path.lower())
16
+ return ext in SUPPORTED_EXTENSIONS
17
+
18
+ def convert_file(input_path, output_path):
19
+ # Load the PyTorch model
20
+ model = torch.load(input_path, map_location=device)
21
+
22
+ if "string_to_param" in model: # Embeddings are a bit of a weird dict.
23
+ (s_model, dmeta) = process_embedding_file(model)
24
+ elif "state_dict" in model: # Ckpts or vaes are a standard list of layers.
25
+ (s_model, dmeta) = process_ckpt_file(model)
26
+ else: # No clue, try simple conversion.
27
+ s_model = model
28
+ dmeta = {"ckpt": None, "step": None, "dim": None}
29
+
30
+ try:
31
+ save_file (s_model, output_path)
32
+ except Exception as e:
33
+ raise ValueError(f"Unknown filetype: {input_path} | {str(e)}")
34
+
35
+ return dmeta
36
+
37
+ def process_embedding_file(model):
38
+ # Extract the embedding tensors
39
+ model_tensors = model.get('string_to_param').get('*')
40
+ s_model = {
41
+ 'emb_params': model_tensors
42
+ }
43
+
44
+ # Metadata extraction.
45
+ dmeta = {"ckpt": None, "step": None, "dim": None}
46
+ if ('sd_checkpoint_name' in model) and (model['sd_checkpoint_name'] is not None):
47
+ dmeta["ckpt"] = model['sd_checkpoint_name']
48
+ if ('step' in model) and (model['step'] is not None):
49
+ dmeta["step"] = model["step"]
50
+ dmeta["dim"] = model_tensors.shape
51
+
52
+ return s_model, dmeta
53
+
54
+ def process_ckpt_file(model):
55
+ # Extract the state dictionary
56
+ s_model = model["state_dict"]
57
+
58
+ # Metadata extraction.
59
+ dmeta = {"ckpt": None, "step": None, "dim": None}
60
+ dmeta["step"] = model.get('step', model.get('global_step'))
61
+
62
+ return s_model, dmeta
63
+
64
+ if verbose:
65
+ # Print the requested training information, if it exists
66
+ step = model.get('step', model.get('global_step'))
67
+ if step is not None:
68
+ print(f"Trained for {step} steps.")
69
+ else:
70
+ print("Step not found in the model.")
71
+ print()
72
+ return s_model