Spaces:
Sleeping
Sleeping
| import torch | |
| from safetensors.torch import save_file, load_file | |
| import gradio as gr | |
| import os | |
def convert_embedding(uploaded_file):
    """Convert an SD1.5 textual-inversion embedding into the SDXL safetensors layout.

    The SD1.5 vectors are copied (as float16) into the 768-wide ``clip_l``
    tensor. The 1280-wide ``clip_g`` tensor is left as zeros, because SD1.5 has
    no counterpart for it.

    Args:
        uploaded_file: The value produced by ``gr.File``. Depending on the
            Gradio version this is either a filepath (``str`` / ``os.PathLike``)
            or a tempfile-like object exposing the path as ``.name``; both are
            accepted.

    Returns:
        Path to the written ``.safetensors`` file. It is placed in a fresh
        temporary directory (so concurrent requests cannot overwrite each
        other) and named after the uploaded file's stem.

    Raises:
        ValueError: If no file was given, the extension is not ``.pt`` or
            ``.safetensors`` (case-insensitive), or the embedding tensor is not
            a 1-D/2-D tensor at most 768 wide.
    """
    import tempfile

    if uploaded_file is None:
        raise ValueError("No file uploaded")

    # Gradio 4+ passes a plain filepath; Gradio 3 passes a tempfile wrapper
    # whose path is in `.name`. Check PathLike first: pathlib.Path also has a
    # `.name` attribute, but it holds only the basename.
    if isinstance(uploaded_file, (str, os.PathLike)):
        input_path = os.fspath(uploaded_file)
    else:
        input_path = uploaded_file.name

    stem, file_extension = os.path.splitext(os.path.basename(input_path))
    file_extension = file_extension.lower()

    #The sample files are probably structured in these ways because the pt files were probably all created with automatic1111, and the safetensors files were probably created with kohya_ss
    #If we learn of other programs that structure the embedding file differently, we'll have to adjust the logic.
    if file_extension == '.pt':
        sd15_embedding = torch.load(input_path, map_location=torch.device('cpu'), weights_only=True)
        sd15_tensor = sd15_embedding['string_to_param']['*']
    elif file_extension == '.safetensors':
        loaded_tensors = load_file(input_path)
        sd15_tensor = loaded_tensors['emb_params']
    else:
        raise ValueError("Unsupported file format")

    # Detach in case the vectors were stored as a trainable Parameter; the
    # copy below must not build an autograd graph.
    sd15_tensor = sd15_tensor.detach()
    if sd15_tensor.ndim == 1:
        # A single vector stored without the leading num_vectors dimension.
        sd15_tensor = sd15_tensor.unsqueeze(0)
    if sd15_tensor.ndim != 2 or sd15_tensor.shape[1] > 768:
        raise ValueError(
            "Expected an SD1.5 embedding of shape (num_vectors, <=768), "
            f"got {tuple(sd15_tensor.shape)}"
        )

    num_vectors, width = sd15_tensor.shape
    clip_g = torch.zeros((num_vectors, 1280), dtype=torch.float16)
    clip_l = torch.zeros((num_vectors, 768), dtype=torch.float16)
    # Narrower inputs are zero-padded on the right up to 768.
    clip_l[:, :width] = sd15_tensor.to(dtype=torch.float16)

    # A unique directory per call avoids clobbering between concurrent users.
    # NOTE: the directory is not removed here; cleanup is left to the OS /
    # hosting environment.
    output_dir = tempfile.mkdtemp(prefix="sdxl_embedding_")
    output_path = os.path.join(output_dir, f"{stem}.safetensors")
    save_file({"clip_g": clip_g, "clip_l": clip_l}, output_path)
    # Return the path to the converted file for download
    return output_path
# Single-function Gradio UI: one file in (the SD1.5 embedding), one file out
# (the converted SDXL safetensors embedding produced by convert_embedding).
iface = gr.Interface(
    convert_embedding,
    gr.File(label="Upload SD1.5 embedding"),
    gr.File(label="Download converted SDXL safetensors embedding"),
    description="Upload an SD1.5 embedding file to convert it to SDXL.",
    title="SD1.5 to SDXL Embedding Converter",
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()