Quiho commited on
Commit
9dbb8e1
·
verified ·
1 Parent(s): 799ef1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -28
app.py CHANGED
@@ -3,40 +3,78 @@ from safetensors.torch import save_file, load_file
3
  import gradio as gr
4
  import os
5
 
6
- def convert_embedding(uploaded_file):
7
- output_path = "embedding.safetensors"
8
- file_extension = os.path.splitext(uploaded_file.name)[1]
9
-
10
- #The sample files are probably structured in these ways because the pt files were probably all created with automatic1111, and the safetensors files were probably created with kohya_ss
11
- #If we learn of other programs that structure the embedding file differently, we'll have to adjust the logic.
12
- if file_extension == '.pt':
13
- sd15_embedding = torch.load(uploaded_file.name, map_location=torch.device('cpu'), weights_only=True)
14
- sd15_tensor = sd15_embedding['string_to_param']['*']
15
- elif file_extension == '.safetensors':
16
- loaded_tensors = load_file(uploaded_file.name)
17
- sd15_tensor = loaded_tensors['emb_params']
18
- else:
19
- raise ValueError("Unsupported file format")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- num_vectors = sd15_tensor.shape[0]
22
- clip_g_shape = (num_vectors, 1280)
23
- clip_l_shape = (num_vectors, 768)
24
- clip_g = torch.zeros(clip_g_shape, dtype=torch.float16)
25
- clip_l = torch.zeros(clip_l_shape, dtype=torch.float16)
26
- clip_l[:sd15_tensor.shape[0], :sd15_tensor.shape[1]] = sd15_tensor.to(dtype=torch.float16)
27
- save_file({"clip_g": clip_g, "clip_l": clip_l}, output_path)
28
 
29
- # Return the path to the converted file for download
30
- return output_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  iface = gr.Interface(
33
  fn=convert_embedding,
34
- inputs=gr.File(label="Upload SD-1.5 embedding"),
35
- outputs=gr.File(label="Download converted SDXL safetensors embedding"),
36
- title="SD-1.5 to SDXL Embedding Converter | Now running on Gradio 5 ⚡",
37
- description="Upload an SD-1.5 embedding file to convert it to SDXL."
 
 
38
  )
39
 
40
  if __name__ == "__main__":
41
  iface.launch()
42
-
 
3
  import gradio as gr
4
  import os
5
 
6
+ def convert_embedding(uploaded_files):
7
+ output_files = []
8
+
9
+ for uploaded_file in uploaded_files:
10
+ file_name, file_extension = os.path.splitext(uploaded_file.name)
11
+ output_path = f"{file_name}_XL.safetensors"
12
+
13
+ if file_extension == '.pt':
14
+ sd15_embedding = torch.load(uploaded_file.name, map_location=torch.device('cpu'))
15
+ sd15_tensor = sd15_embedding.get('string_to_param', {}).get('*')
16
+ elif file_extension == '.safetensors':
17
+ loaded_tensors = load_file(uploaded_file.name)
18
+ sd15_tensor = loaded_tensors.get('emb_params')
19
+ else:
20
+ raise ValueError(f"Unsupported file format: {file_extension}")
21
+
22
+ if sd15_tensor is None:
23
+ raise ValueError(f"Invalid embedding structure in file: {uploaded_file.name}")
24
+
25
+ num_vectors = sd15_tensor.shape[0]
26
+ clip_g_shape = (num_vectors, 1280)
27
+ clip_l_shape = (num_vectors, 768)
28
+ clip_g = torch.zeros(clip_g_shape, dtype=torch.float16)
29
+ clip_l = torch.zeros(clip_l_shape, dtype=torch.float16)
30
+ clip_l[:sd15_tensor.shape[0], :sd15_tensor.shape[1]] = sd15_tensor.to(dtype=torch.float16)
31
+
32
+ save_file({"clip_g": clip_g, "clip_l": clip_l}, output_path)
33
+ output_files.append(output_path)
34
 
35
+ return output_files
 
 
 
 
 
 
36
 
37
+ custom_css = """
38
+ body {
39
+ background-color: #1e1e2e;
40
+ color: #ffffff;
41
+ font-family: Arial, sans-serif;
42
+ }
43
+ .gradio-container {
44
+ max-width: 800px;
45
+ margin: auto;
46
+ padding: 20px;
47
+ border-radius: 10px;
48
+ background: #2a2a3a;
49
+ box-shadow: 0 0 10px rgba(0, 0, 0, 0.3);
50
+ }
51
+ .gradio-container h1 {
52
+ text-align: center;
53
+ font-size: 24px;
54
+ }
55
+ .gradio-container button {
56
+ background-color: #4caf50;
57
+ color: white;
58
+ padding: 10px 15px;
59
+ border: none;
60
+ border-radius: 5px;
61
+ cursor: pointer;
62
+ font-size: 16px;
63
+ }
64
+ .gradio-container button:hover {
65
+ background-color: #45a049;
66
+ }
67
+ """
68
 
69
  iface = gr.Interface(
70
  fn=convert_embedding,
71
+ inputs=gr.File(label="Upload SD-1.5 embeddings", type="file", multiple=True),
72
+ outputs=gr.File(label="Download converted SDXL safetensors embeddings", type="file", multiple=True),
73
+ title="SD-1.5 to SDXL Embedding Converter | Now supports multiple files ⚡",
74
+ description="Upload one or more SD-1.5 embedding files to convert them to SDXL. Stylish and efficient!",
75
+ theme="default",
76
+ css=custom_css
77
  )
78
 
79
  if __name__ == "__main__":
80
  iface.launch()