import os
import subprocess

import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import login, HfApi
from huggingface_hub.errors import HfHubHTTPError


# ---------------------- DEPENDENCIES ----------------------

def install_dependencies_gradio():
    """Installs the necessary dependencies for the Gradio app. Run this ONCE."""
    try:
        subprocess.run(
            ["pip", "install", "-U", "torch", "diffusers", "transformers",
             "accelerate", "safetensors", "huggingface_hub", "xformers"],
            check=True,
        )
        print("Dependencies installed successfully.")
    except Exception as e:
        print(f"Error installing dependencies: {e}")


# ---------------------- UTILITY FUNCTIONS ----------------------

def get_save_dtype(save_precision_as):
    """Maps the user's precision choice to a torch dtype."""
    if save_precision_as == "fp16":
        return torch.float16
    elif save_precision_as == "bf16":
        return torch.bfloat16
    elif save_precision_as == "float":
        return torch.float32
    return None


def determine_load_checkpoint(model_to_load):
    """Returns True for a single-file checkpoint, False for a Diffusers
    directory, and None if the path is neither."""
    if model_to_load.endswith((".ckpt", ".safetensors")):
        return True
    elif os.path.isdir(model_to_load):
        required_folders = {"unet", "text_encoder", "text_encoder_2",
                            "tokenizer", "tokenizer_2", "scheduler", "vae"}
        if required_folders.issubset(set(os.listdir(model_to_load))) and \
                os.path.isfile(os.path.join(model_to_load, "model_index.json")):
            return False
    return None


def increment_filename(filename):
    """Increments the filename to avoid overwriting existing files."""
    base, ext = os.path.splitext(filename)
    counter = 1
    while os.path.exists(filename):
        filename = f"{base}({counter}){ext}"
        counter += 1
    return filename
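
# A quick illustration of how the helpers above behave (the paths are
# hypothetical, shown only for the example):
#
#   get_save_dtype("fp16")                               # -> torch.float16
#   determine_load_checkpoint("sdxl.safetensors")        # -> True  (single file)
#   determine_load_checkpoint("/content/sdxl-diffusers") # -> False (Diffusers dir)
#   increment_filename("/content/out.safetensors")       # -> "/content/out(1).safetensors"
#                                                        #    if the file already exists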

# ---------------------- UPLOAD HELPERS ----------------------

def create_model_repo(api, user, orgs_name, model_name, make_private=False):
    """Creates a Hugging Face model repository if it doesn't already exist."""
    repo_id = (f"{orgs_name}/{model_name.strip()}" if orgs_name
               else f"{user['name']}/{model_name.strip()}")
    try:
        api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
        print(f"Model repo '{repo_id}' created.")
    except HfHubHTTPError:
        print(f"Model repo '{repo_id}' already exists.")
    return repo_id


# ---------------------- MODEL LOADING AND CONVERSION ----------------------

def load_sdxl_model(model_to_load, is_load_checkpoint, load_dtype):
    """Loads the SDXL model from a checkpoint file or a Diffusers directory."""
    model_load_message = "checkpoint" if is_load_checkpoint else (
        "Diffusers" + (" as fp16" if load_dtype == torch.float16 else ""))
    print(f"Loading {model_load_message}: {model_to_load}")
    if is_load_checkpoint:
        return load_from_sdxl_checkpoint(model_to_load, load_dtype)
    return load_sdxl_from_diffusers(model_to_load, load_dtype)


def load_from_sdxl_checkpoint(model_to_load, load_dtype):
    """Loads the SDXL components from a single-file (.ckpt/.safetensors)
    checkpoint via diffusers' single-file loader."""
    pipeline = StableDiffusionXLPipeline.from_single_file(
        model_to_load, torch_dtype=load_dtype)
    print(f"Loaded from checkpoint: {model_to_load}")
    return pipeline.text_encoder, pipeline.text_encoder_2, pipeline.vae, pipeline.unet


def load_sdxl_from_diffusers(model_to_load, load_dtype):
    """Loads an SDXL model from a Diffusers model directory or Hub repo."""
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        model_to_load, torch_dtype=load_dtype)
    return pipeline.text_encoder, pipeline.text_encoder_2, pipeline.vae, pipeline.unet


def convert_and_save_sdxl_model(model_to_save, is_save_checkpoint, loaded_model_data,
                                save_dtype, reference_model):
    """Saves the loaded SDXL components as either a checkpoint or a Diffusers model."""
    text_encoder1, text_encoder2, vae, unet = loaded_model_data
    if is_save_checkpoint:
        save_sdxl_as_checkpoint(model_to_save, text_encoder1, text_encoder2, vae, unet, save_dtype)
    else:
        save_sdxl_as_diffusers(model_to_save, text_encoder1, text_encoder2, vae, unet,
                               save_dtype, reference_model)


def save_sdxl_as_checkpoint(model_to_save, text_encoder1, text_encoder2, vae, unet, save_dtype):
    """Saves the SDXL components as a single-file checkpoint.

    Not implemented yet: a checkpoint readable by other tools requires
    remapping the Diffusers state-dict keys to the original Stability AI
    layout (see save_checkpoint_sketch() below for one possible approach).
    """
    print(f"[stub] Checkpoint saving not implemented; would write: {model_to_save}")


def save_sdxl_as_diffusers(model_to_save, text_encoder1, text_encoder2, vae, unet,
                           save_dtype, reference_model):
    """Saves the SDXL model in Diffusers format.

    The tokenizers and scheduler are not part of the loaded components, so
    they are taken from the reference model (e.g.
    stabilityai/stable-diffusion-xl-base-1.0)."""
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        reference_model or "stabilityai/stable-diffusion-xl-base-1.0",
        vae=vae,
        text_encoder=text_encoder1,
        text_encoder_2=text_encoder2,
        unet=unet,
    )
    if save_dtype is not None:
        pipeline.to(save_dtype)
    pipeline.save_pretrained(model_to_save)
    print(f"Model saved in Diffusers format: {model_to_save}")


# ---------------------- UPLOAD FUNCTION ----------------------

def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
    """Uploads a saved model (file or folder) to the Hugging Face Hub."""
    login(hf_token, add_to_git_credential=True)
    api = HfApi()
    user = api.whoami(hf_token)
    model_repo = create_model_repo(api, user, orgs_name, model_name, make_private)
    if os.path.isdir(model_path):
        api.upload_folder(folder_path=model_path, repo_id=model_repo, repo_type="model")
    else:
        api.upload_file(path_or_fileobj=model_path,
                        path_in_repo=os.path.basename(model_path),
                        repo_id=model_repo, repo_type="model")
    print(f"Model uploaded to: https://huggingface.co/{model_repo}")


# ---------------------- GRADIO INTERFACE ----------------------

def main(model_to_load, save_precision_as, epoch, global_step, reference_model,
         output_path, hf_token, orgs_name, model_name, make_private):
    """Main function orchestrating the load -> convert -> upload process."""
    save_dtype = get_save_dtype(save_precision_as)  # also used as the load dtype
    is_load_checkpoint = determine_load_checkpoint(model_to_load)
    if is_load_checkpoint is None:
        return "Error: model to load is neither a checkpoint file nor a Diffusers directory."
    is_save_checkpoint = not is_load_checkpoint  # save in the opposite format
    loaded_model_data = load_sdxl_model(model_to_load, is_load_checkpoint, save_dtype)
    # epoch / global_step are intended as metadata for checkpoint saving,
    # which is not implemented yet (see save_sdxl_as_checkpoint).
    convert_and_save_sdxl_model(output_path, is_save_checkpoint, loaded_model_data,
                                save_dtype, reference_model)
    upload_to_huggingface(output_path, hf_token, orgs_name, model_name, make_private)
    return "Conversion and upload completed successfully!"
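
# save_sdxl_as_checkpoint() above is still a stub. Below is a minimal,
# assumption-laden sketch of one way to finish it: merge the component state
# dicts under the original SDXL single-file prefixes and write one
# .safetensors file. Note that a checkpoint other tools can actually load
# also needs each Diffusers parameter name remapped to the original
# Stability AI name (as done by the diffusers conversion scripts or kohya's
# sd-scripts); that remapping is omitted here, so treat this as an
# illustration rather than a drop-in implementation.

from safetensors.torch import save_file

def save_checkpoint_sketch(path, text_encoder1, text_encoder2, vae, unet, save_dtype=None):
    """Sketch: write all components into a single .safetensors checkpoint."""
    state_dict = {}
    for prefix, module in [
        ("conditioner.embedders.0.transformer.", text_encoder1),  # CLIP-L
        ("conditioner.embedders.1.model.", text_encoder2),        # OpenCLIP-G
        ("first_stage_model.", vae),
        ("model.diffusion_model.", unet),
    ]:
        for key, tensor in module.state_dict().items():
            if save_dtype is not None:
                tensor = tensor.to(save_dtype)
            # save_file requires contiguous tensors
            state_dict[prefix + key] = tensor.contiguous()
    save_file(state_dict, path)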

with gr.Blocks() as demo:
    model_to_load = gr.Textbox(label="Model to Load (Checkpoint or Diffusers)",
                               placeholder="Path to model")
    save_precision_as = gr.Dropdown(choices=["fp16", "bf16", "float"],
                                    label="Save Precision As")
    epoch = gr.Number(value=0, label="Epoch to Write (Checkpoint)")
    global_step = gr.Number(value=0, label="Global Step to Write (Checkpoint)")
    reference_model = gr.Textbox(label="Reference Diffusers Model",
                                 placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0")
    output_path = gr.Textbox(label="Output Path", value="/content/output")
    hf_token = gr.Textbox(label="Hugging Face Token",
                          placeholder="Your Hugging Face write token")
    orgs_name = gr.Textbox(label="Organization Name (Optional)",
                           placeholder="Your organization name")
    model_name = gr.Textbox(label="Model Name",
                            placeholder="The name of your model on Hugging Face")
    make_private = gr.Checkbox(label="Make Repository Private", value=False)

    convert_button = gr.Button("Convert and Upload")
    output = gr.Markdown()

    convert_button.click(
        fn=main,
        inputs=[model_to_load, save_precision_as, epoch, global_step,
                reference_model, output_path, hf_token, orgs_name,
                model_name, make_private],
        outputs=output,
    )

demo.launch()
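
# The /content default output path suggests a Colab runtime; there the UI is
# usually reached through a public link, which gradio's standard `share` flag
# provides:
#
#   demo.launch(share=True)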