Spaces:
Running
Running
| import os | |
| import gradio as gr | |
| import torch | |
| from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL | |
| from transformers import CLIPTextModel | |
| from safetensors.torch import load_file | |
| from collections import OrderedDict | |
| import re | |
| import json | |
| import gdown | |
| import requests | |
| import subprocess | |
| from urllib.parse import urlparse, unquote | |
| from pathlib import Path | |
| import tempfile | |
| from tqdm import tqdm | |
| import psutil | |
| import math | |
| import shutil | |
| import hashlib | |
| from datetime import datetime | |
| from typing import Dict, List, Optional | |
| from huggingface_hub import login, HfApi | |
| from huggingface_hub.errors import HfHubHTTPError | |
| # ---------------------- DEPENDENCIES ---------------------- | |
def install_dependencies_gradio():
    """Install the packages this app needs. Run this ONCE (e.g. in a fresh notebook).

    Packages are installed with ``sys.executable -m pip`` so they land in the
    interpreter that is running this app, not whichever ``pip`` is first on PATH.
    ``check=True`` makes a non-zero pip exit status raise, so a failed install is
    reported as a failure instead of printing the success message.
    """
    import sys

    packages = [
        "torch",
        "diffusers",
        "transformers",
        "accelerate",
        "safetensors",
        "huggingface_hub",
        "xformers",
    ]
    try:
        subprocess.run([sys.executable, "-m", "pip", "install", "-U", *packages], check=True)
        print("Dependencies installed successfully.")
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: pip ran but failed; OSError: the interpreter could not be launched.
        print(f"Error installing dependencies: {e}")
| # ---------------------- UTILITY FUNCTIONS ---------------------- | |
def get_save_dtype(save_precision_as):
    """Translate a precision label from the UI into the matching ``torch`` dtype.

    "fp16" -> torch.float16, "bf16" -> torch.bfloat16, "float" -> torch.float32.
    Any other value (including None) yields None, meaning "do not cast".
    """
    label_to_dtype = (
        ("fp16", torch.float16),
        ("bf16", torch.bfloat16),
        ("float", torch.float32),
    )
    for label, dtype in label_to_dtype:
        if save_precision_as == label:
            return dtype
    return None
def determine_load_checkpoint(model_to_load):
    """Classify ``model_to_load`` as a single-file checkpoint or a Diffusers directory.

    Returns:
        True  -- the path ends in ``.ckpt`` or ``.safetensors``.
        False -- the path is a directory containing every SDXL component folder
                 plus ``model_index.json``.
        None  -- anything else (including a directory missing those entries).
    """
    if model_to_load.endswith((".ckpt", ".safetensors")):
        return True
    if not os.path.isdir(model_to_load):
        return None

    expected_entries = {
        "unet",
        "text_encoder",
        "text_encoder_2",
        "tokenizer",
        "tokenizer_2",
        "scheduler",
        "vae",
    }
    present_entries = set(os.listdir(model_to_load))
    if expected_entries <= present_entries and os.path.isfile(
        os.path.join(model_to_load, "model_index.json")
    ):
        return False
    return None
def increment_filename(filename):
    """Return ``filename`` or, if it already exists, the first free ``name(N).ext`` variant.

    N starts at 1 and counts up until a path that does not exist is found.
    """
    stem, suffix = os.path.splitext(filename)
    candidate = filename
    attempt = 0
    while os.path.exists(candidate):
        attempt += 1
        candidate = f"{stem}({attempt}){suffix}"
    return candidate
| # ---------------------- HUB REPOSITORY HELPERS ---------------------- | |
def create_model_repo(api, user, orgs_name, model_name, make_private=False):
    """Create the Hugging Face model repository if needed and return its id.

    Args:
        api: an ``HfApi``-like client; only its ``create_repo`` method is used.
        user: the dict returned by ``whoami()``; its ``"name"`` key is used as the
            namespace when ``orgs_name`` is blank.
        orgs_name: optional organization namespace; blank/None means the user's own.
        model_name: repository name; surrounding whitespace is stripped.
        make_private: create the repository as private.

    Returns:
        The repository id, ``"<namespace>/<model_name>"``.

    Raises:
        ValueError: if ``model_name`` is blank.
        Any error from ``api.create_repo`` other than "already exists" (e.g. an
        invalid token or missing organization permissions) propagates.
    """
    name = (model_name or "").strip()
    if not name:
        raise ValueError("A model name is required to create the repository.")
    namespace = (orgs_name or "").strip() or user["name"]
    repo_id = f"{namespace}/{name}"
    # exist_ok=True makes creation idempotent. The previous blanket
    # `except HfHubHTTPError` reported every HTTP failure (401/403, ...) as
    # "already exists" and carried on to a doomed upload.
    api.create_repo(repo_id=repo_id, repo_type="model", private=make_private, exist_ok=True)
    print(f"Model repo '{repo_id}' is ready.")
    return repo_id
| # ---------------------- MODEL LOADING AND CONVERSION ---------------------- | |
def load_sdxl_model(model_to_load, is_load_checkpoint, load_dtype):
    """Load SDXL components via the checkpoint loader or the Diffusers loader.

    A truthy ``is_load_checkpoint`` selects the single-file loader; anything else
    (False or None) selects the Diffusers loader, which receives ``load_dtype``.
    Returns whatever the chosen loader returns:
    ``(text_encoder1, text_encoder2, vae, unet)``.
    """
    if is_load_checkpoint:
        description = "checkpoint"
    else:
        precision_note = " as fp16" if load_dtype == torch.float16 else ""
        description = f"Diffusers{precision_note}"
    print(f"Loading {description}: {model_to_load}")

    if not is_load_checkpoint:
        return load_sdxl_from_diffusers(model_to_load, load_dtype)
    return load_from_sdxl_checkpoint(model_to_load)
def load_from_sdxl_checkpoint(model_to_load):
    """Load SDXL components from a single-file checkpoint (``.ckpt``/``.safetensors``).

    The previous version was a stub that returned four ``None`` values while
    printing "Loaded", which made every later step fail or save nothing. This
    uses Diffusers' single-file loader to build a full pipeline and returns its
    parts.

    Returns:
        ``(text_encoder1, text_encoder2, vae, unet)`` -- the same order as
        ``load_sdxl_from_diffusers``.
    """
    pipeline = StableDiffusionXLPipeline.from_single_file(model_to_load)
    print(f"Loaded from checkpoint: {model_to_load}")
    return pipeline.text_encoder, pipeline.text_encoder_2, pipeline.vae, pipeline.unet
def load_sdxl_from_diffusers(model_to_load, load_dtype):
    """Load an SDXL Diffusers model and return its trainable components.

    ``model_to_load`` is passed straight to ``StableDiffusionXLPipeline.from_pretrained``
    together with ``torch_dtype=load_dtype``.

    Returns:
        ``(text_encoder1, text_encoder2, vae, unet)`` taken from the loaded pipeline.
    """
    pipe = StableDiffusionXLPipeline.from_pretrained(model_to_load, torch_dtype=load_dtype)
    return pipe.text_encoder, pipe.text_encoder_2, pipe.vae, pipe.unet
def convert_and_save_sdxl_model(model_to_load, is_save_checkpoint, loaded_model_data, save_dtype):
    """Write previously loaded SDXL components in the requested format.

    Args:
        model_to_load: despite its name, the destination path handed to the saver.
        is_save_checkpoint: truthy -> single-file checkpoint saver, else Diffusers saver.
        loaded_model_data: 4-tuple ``(text_encoder1, text_encoder2, vae, unet)``.
        save_dtype: dtype forwarded unchanged to the saver.
    """
    te1, te2, vae_model, unet_model = loaded_model_data
    save_fn = save_sdxl_as_checkpoint if is_save_checkpoint else save_sdxl_as_diffusers
    save_fn(model_to_load, te1, te2, vae_model, unet_model, save_dtype)
def save_sdxl_as_checkpoint(model_to_save, text_encoder1, text_encoder2, vae, unet, save_dtype):
    """Save SDXL components as a single-file checkpoint -- not implemented yet.

    Writing a ``.ckpt``/``.safetensors`` checkpoint requires mapping Diffusers
    state-dict keys back to the original SDXL layout, which this app does not do.
    The previous stub printed "Model saved as checkpoint" without writing
    anything, so the caller went on to upload a model that did not exist.
    Raising makes that failure visible.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "Saving SDXL as a single-file checkpoint is not implemented yet "
        f"(requested output: {model_to_save})."
    )
def save_sdxl_as_diffusers(model_to_save, text_encoder1, text_encoder2, vae, unet, save_dtype, reference_model=None):
    """Save SDXL components as a Diffusers model directory.

    ``StableDiffusionXLPipeline`` also requires tokenizers and a scheduler, which
    the loaded components do not carry. Constructing the pipeline directly
    without them raises TypeError. They are therefore taken from
    ``reference_model``, while the supplied modules override that model's
    weights.

    Args:
        model_to_save: output directory passed to ``save_pretrained``.
        text_encoder1, text_encoder2, vae, unet: model components; none may be None.
        save_dtype: if not None, every component is cast to it (in place) before saving.
        reference_model: Hub id or path providing tokenizers/scheduler; defaults to
            "stabilityai/stable-diffusion-xl-base-1.0".

    Raises:
        ValueError: if any component is None.
    """
    components = {
        "vae": vae,
        "text_encoder": text_encoder1,
        "text_encoder_2": text_encoder2,
        "unet": unet,
    }
    missing = [name for name, module in components.items() if module is None]
    if missing:
        raise ValueError(f"Cannot save Diffusers model; missing components: {', '.join(missing)}")

    if save_dtype is not None:
        for module in components.values():
            module.to(save_dtype)

    # Passing modules as keyword arguments makes from_pretrained use them instead
    # of loading those parts from the reference model.
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        reference_model or "stabilityai/stable-diffusion-xl-base-1.0",
        **components,
    )
    pipeline.save_pretrained(model_to_save)
    print(f"Model saved as Diffusers format: {model_to_save}")
| # ---------------------- UPLOAD FUNCTION ---------------------- | |
def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
    """Upload a saved model (a directory or a single file) to the Hugging Face Hub.

    Args:
        model_path: local directory (uploaded with ``upload_folder``) or file
            (uploaded under its base name with ``upload_file``).
        hf_token: Hugging Face write token for this request.
        orgs_name, model_name, make_private: forwarded to ``create_model_repo``.

    Returns:
        The URL of the model repository.

    Raises:
        ValueError: if ``hf_token`` is blank.
        FileNotFoundError: if ``model_path`` does not exist.
    """
    token = (hf_token or "").strip()
    if not token:
        raise ValueError("A Hugging Face write token is required to upload.")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Nothing to upload: '{model_path}' does not exist.")

    # A per-request client instead of login(): login() stores the token for the
    # whole process (and, with add_to_git_credential, in git's credential store),
    # so on a shared Space one user's token could be reused for another user's work.
    api = HfApi(token=token)
    user = api.whoami()
    model_repo = create_model_repo(api, user, orgs_name, model_name, make_private)

    if os.path.isdir(model_path):
        api.upload_folder(folder_path=model_path, repo_id=model_repo, repo_type="model")
    else:
        api.upload_file(
            path_or_fileobj=model_path,
            path_in_repo=os.path.basename(model_path),
            repo_id=model_repo,
            repo_type="model",
        )

    url = f"https://huggingface.co/{model_repo}"
    print(f"Model uploaded to: {url}")
    return url
| # ---------------------- GRADIO INTERFACE ---------------------- | |
def main(model_to_load, save_precision_as, epoch, global_step, reference_model, output_path, fp16, hf_token, orgs_name, model_name, make_private):
    """Convert an SDXL model to the other format, save it, and upload it to the Hub.

    Gradio callback for the "Convert and Upload" button.

    Args:
        model_to_load: checkpoint file, local Diffusers directory, or another
            identifier handed to the Diffusers loader.
        save_precision_as: "fp16", "bf16" or "float"; any other value keeps the
            loaded dtype.
        epoch, global_step: checkpoint metadata fields; not used by the current
            save functions.
        reference_model: not forwarded yet. TODO: thread it through to the saver.
        output_path: where the converted model is written; this is what gets uploaded.
        fp16: load a Diffusers model in float16.
        hf_token, orgs_name, model_name, make_private: upload settings.

    Returns:
        A status message shown in the app (prefixed with "Error:" on failure).
    """
    import traceback

    model_to_load = (model_to_load or "").strip()
    output_path = (output_path or "").strip()
    hf_token = (hf_token or "").strip()
    # Validate everything up front so a long conversion does not fail at upload time.
    if not model_to_load:
        return "Error: please provide a model to load."
    if not output_path:
        return "Error: please provide an output path."
    if not hf_token:
        return "Error: please provide a Hugging Face write token."
    if not (model_name or "").strip():
        return "Error: please provide a model name."

    # `fp16` controls the dtype used when *loading* a Diffusers model, and
    # `save_precision_as` controls the dtype of the saved result. The original
    # passed the save precision to both steps and ignored `fp16`.
    load_dtype = torch.float16 if fp16 else None
    save_dtype = get_save_dtype(save_precision_as)

    # True -> checkpoint file, False -> local Diffusers directory, None -> anything
    # else, which load_sdxl_model hands to the Diffusers loader. The output is
    # always the opposite format.
    is_load_checkpoint = determine_load_checkpoint(model_to_load)
    is_save_checkpoint = not is_load_checkpoint

    try:
        loaded_model_data = load_sdxl_model(model_to_load, is_load_checkpoint, load_dtype)
        # Save to output_path. The original wrote over model_to_load and then
        # uploaded output_path, which held nothing.
        convert_and_save_sdxl_model(output_path, is_save_checkpoint, loaded_model_data, save_dtype)
        upload_to_huggingface(output_path, hf_token, orgs_name, model_name, make_private)
    except Exception as e:  # UI boundary: log the traceback and show the reason to the user
        traceback.print_exc()
        return f"Error: {e}"
    return "Conversion and upload completed successfully!"
with gr.Blocks() as demo:
    model_to_load = gr.Textbox(label="Model to Load (Checkpoint or Diffusers)", placeholder="Path to model")
    save_precision_as = gr.Dropdown(choices=["fp16", "bf16", "float"], label="Save Precision As")
    # precision=0 makes Gradio deliver integers rather than floats for these counters.
    epoch = gr.Number(value=0, precision=0, label="Epoch to Write (Checkpoint)")
    global_step = gr.Number(value=0, precision=0, label="Global Step to Write (Checkpoint)")
    reference_model = gr.Textbox(label="Reference Diffusers Model", placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0")
    output_path = gr.Textbox(label="Output Path", value="/content/output")
    # `fp16` is wired into `inputs` below and required by main(), but was never
    # defined, which raised NameError when this script ran.
    fp16 = gr.Checkbox(label="Load Diffusers Model as fp16", value=False)
    # type="password" keeps the write token masked in the browser.
    hf_token = gr.Textbox(label="Hugging Face Token", placeholder="Your Hugging Face write token", type="password")
    orgs_name = gr.Textbox(label="Organization Name (Optional)", placeholder="Your organization name")
    model_name = gr.Textbox(label="Model Name", placeholder="The name of your model on Hugging Face")
    make_private = gr.Checkbox(label="Make Repository Private", value=False)
    convert_button = gr.Button("Convert and Upload")
    output = gr.Markdown()
    # The order of `inputs` must match main()'s positional parameters.
    convert_button.click(
        fn=main,
        inputs=[
            model_to_load,
            save_precision_as,
            epoch,
            global_step,
            reference_model,
            output_path,
            fp16,
            hf_token,
            orgs_name,
            model_name,
            make_private,
        ],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()