import os
import gc
import torch
import shutil
import uuid
import gradio as gr
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file

# Per-architecture substrings of tensor names. In INT8 mode, any tensor whose
# name contains one of these substrings is kept in BF16 instead of being quantized.
ARCH_PROFILES = {
    "FLUX / Generic Rectified Flow": [
        "norm", "ln_", "embed", "time_in", "vector_in", "guidance_in", "txt_in", "img_in",
    ],
    "Z-Image / DiT Core": [
        "t_embedder", "cap_embedder", "all_x_embedder", "all_final_layer", "rope_embedder",
        "embed_tokens", "norm", "ln_", "shared",
    ],
    "Stable Diffusion (SDXL/SD3)": ["time_embed", "label_emb", "norm", "ln_", "out."],
}


def _is_int8_candidate(key, tensor, exclude_patterns):
    """Return True if *tensor* should be stored as row-wise INT8 plus a scale.

    Only non-empty 2-D floating-point tensors named ``weight`` or ``*.weight``
    qualify, and only when none of *exclude_patterns* occurs in *key*.
    """
    return (
        (key == "weight" or key.endswith(".weight"))
        and tensor.ndim == 2
        and tensor.numel() > 0
        and tensor.is_floating_point()
        and not any(pattern in key for pattern in exclude_patterns)
    )


def _quantize_int8_rowwise(weight):
    """Symmetric per-row INT8 quantization of a 2-D weight.

    Returns ``(int8_weight, bf16_scale)``. ``bf16_scale`` has shape ``(rows, 1)``,
    so ``int8_weight * bf16_scale`` approximates *weight*.
    """
    # Upcast (FP8/FP16/BF16 -> FP32) so that abs/max/divide/round run at full
    # precision. A BF16 quotient near 127 is only resolved to 0.5, which can
    # flip the rounding.
    w = weight.to(torch.float32)
    scale = (w.abs().amax(dim=1, keepdim=True) / 127.0).clamp(min=1e-8)
    # Quantize against the BF16-rounded scale that is actually stored, so that
    # dequantization uses exactly the scale the integers were computed with.
    scale = scale.to(torch.bfloat16)
    q_weight = torch.round(w / scale.to(torch.float32)).clamp(-127, 127).to(torch.int8)
    return q_weight, scale


def _cast_float(tensor, dtype):
    """Cast a floating-point *tensor* to *dtype*; non-float tensors pass through."""
    if not tensor.is_floating_point() or tensor.dtype == dtype:
        return tensor
    if dtype == torch.float8_e4m3fn:
        # float8_e4m3fn has no infinity, so out-of-range values would become NaN
        # on cast. Saturate them to the largest finite value first.
        limit = torch.finfo(dtype).max
        tensor = tensor.clamp(-limit, limit)
    return tensor.to(dtype)


def _convert_tensors(tensors, is_int8, target_dtype, exclude_patterns):
    """Return a new state dict holding every tensor of *tensors*, converted.

    INT8 mode: each eligible weight ``<prefix>weight`` becomes two entries,
    ``<prefix>weight_int8`` and ``<prefix>weight_scale``. Every other
    floating-point tensor becomes BF16.

    Other modes: floating-point tensors are cast to *target_dtype*.

    Non-float tensors are kept unchanged in all modes. Entries are popped from
    *tensors* as they are converted, so a file's original and converted copies
    are never both fully resident.
    """
    converted = {}
    for key in list(tensors):
        value = tensors.pop(key)
        if not is_int8:
            converted[key] = _cast_float(value, target_dtype)
        elif _is_int8_candidate(key, value, exclude_patterns):
            # "blocks.0.attn.weight" -> prefix "blocks.0.attn."; "weight" -> "".
            prefix = key[: -len("weight")]
            q_weight, scale = _quantize_int8_rowwise(value)
            converted[f"{prefix}weight_int8"] = q_weight
            converted[f"{prefix}weight_scale"] = scale
        else:
            converted[key] = _cast_float(value, torch.bfloat16)
    return converted


def _convert_file(src_path, dst_path, is_int8, target_dtype, exclude_patterns):
    """Convert the .safetensors file at *src_path* and write the result to *dst_path*."""
    # NOTE(review): the source header's safetensors metadata is not carried over.
    # For sharded checkpoints, the *.safetensors.index.json is copied unchanged,
    # so in INT8 mode its weight_map still lists the original "*.weight" keys.
    converted = _convert_tensors(load_file(src_path), is_int8, target_dtype, exclude_patterns)
    save_file(converted, dst_path)


def convert_and_upload(token, source_repo, target_repo, precision, target_components, arch_profile):
    """Copy a Hub repo into *target_repo*, converting the selected safetensors files.

    Files of *source_repo* are processed one at a time. Each is downloaded into a
    per-run temporary cache and uploaded to *target_repo*: converted if it is a
    ``.safetensors`` file inside one of *target_components*, otherwise as-is. The
    download is deleted before the next file. Top-level ``.safetensors`` files
    are skipped.

    Args:
        token: Hugging Face token with write access.
        source_repo: Repo id to read from, e.g. ``"org/model"``.
        target_repo: Repo id to create (if missing) and write to.
        precision: One of ``"FP8"``, ``"FP16"``, ``"BF16"`` or ``"INT8"``.
        target_components: Folder names whose safetensors files are converted.
        arch_profile: Key of ``ARCH_PROFILES`` that selects the INT8 exclusions.

    Yields:
        The accumulated progress log (one message per line) after each step.
        Earlier messages, including per-file errors, stay visible in a
        streaming output box.
    """
    log_lines = []

    def log(message):
        log_lines.append(message)
        return "\n".join(log_lines)

    source_repo = (source_repo or "").strip()
    target_repo = (target_repo or "").strip()

    if not token:
        yield log("❌ Error: Please provide a valid Hugging Face Write Token.")
        return
    if not source_repo:
        yield log("❌ Error: Please provide a Source Repository.")
        return
    if "/" not in target_repo:
        yield log("❌ Error: Target Repository must be in format 'username/repo-name'.")
        return
    if not target_components:
        yield log("❌ Error: Please select at least one component to quantize.")
        return

    float_dtypes = {
        "FP8": torch.float8_e4m3fn,
        "FP16": torch.float16,
        "BF16": torch.bfloat16,
    }
    is_int8 = precision == "INT8"
    target_dtype = float_dtypes.get(precision)
    if not is_int8 and target_dtype is None:
        yield log(f"❌ Error: Unsupported precision '{precision}'.")
        return

    api = HfApi(token=token)

    yield log(f"🔄 Verifying target repo: {target_repo}...")
    try:
        api.create_repo(repo_id=target_repo, exist_ok=True, private=False)
    except Exception as e:
        yield log(f"❌ Error creating repo: {str(e)}")
        return

    yield log(f"📋 Fetching files from {source_repo}...")
    try:
        files = api.list_repo_files(source_repo)
    except Exception as e:
        yield log(f"❌ Error fetching files: {str(e)}")
        return

    # Unique per run, so concurrent sessions never share downloads or converted
    # output. Keeping the output inside cache_dir means rmtree also removes it.
    cache_dir = f"./hf_cache_{uuid.uuid4().hex[:8]}"
    converted_path = os.path.join(cache_dir, "converted.safetensors")
    exclude_patterns = ARCH_PROFILES.get(arch_profile, [])
    success_count, error_count = 0, 0

    try:
        for file in files:
            if "/" not in file and file.endswith(".safetensors"):
                yield log(f"🗑️ Auto-skipping massive root model: {file}...")
                continue

            yield log(f"⏳ Processing {file}...")
            try:
                os.makedirs(cache_dir, exist_ok=True)
                local_path = hf_hub_download(
                    repo_id=source_repo, filename=file, cache_dir=cache_dir, token=token
                )
                # Match whole folder names, so "vae" does not also select "tiny_vae/".
                folders = file.split("/")[:-1]
                in_target_component = any(comp in folders for comp in target_components)

                if file.endswith(".safetensors") and in_target_component:
                    yield log(f"🧠 Quantizing {file} to {precision}...")
                    _convert_file(local_path, converted_path, is_int8, target_dtype, exclude_patterns)
                    gc.collect()
                    yield log(f"☁️ Uploading {precision} version of {file}...")
                    api.upload_file(
                        path_or_fileobj=converted_path, path_in_repo=file, repo_id=target_repo
                    )
                else:
                    yield log(f"☁️ Copying {file} as-is...")
                    api.upload_file(
                        path_or_fileobj=local_path, path_in_repo=file, repo_id=target_repo
                    )
                success_count += 1
            except Exception as e:
                error_count += 1
                yield log(f"⚠️ Error processing {file}: {str(e)}\nSkipping...")
            finally:
                # Free disk (download and converted output) before the next file,
                # on success and failure alike.
                shutil.rmtree(cache_dir, ignore_errors=True)
                gc.collect()
    finally:
        # Also runs if the generator is closed before the loop finishes.
        shutil.rmtree(cache_dir, ignore_errors=True)

    yield log(f"✅ Finished! Processed: {success_count} | Errors: {error_count}.")
# --- UI LOGIC ---
def generate_target_repo(source, precision):
    """Suggest a target repo id of the form ``your-username/<model>-<precision>``.

    ``<model>`` is the last path segment of *source*. Returns ``""`` when
    *source* is blank, so the field is cleared rather than being filled with
    ``your-username/-<precision>``.
    """
    model_name = (source or "").strip().rstrip("/").split("/")[-1]
    if not model_name:
        return ""
    return f"your-username/{model_name}-{precision}"


def toggle_int8_warning(precision):
    """Show the INT8 key-layout warning only while INT8 is the selected precision."""
    return gr.update(visible=(precision == "INT8"))


# --- GUI ---
# The theme is passed to demo.launch() at the bottom rather than to gr.Blocks().
# NOTE(review): this placement assumes a Gradio release whose launch() accepts
# `theme` (Gradio 6 moved it there); older releases take it on gr.Blocks().
with gr.Blocks() as demo:
    gr.Markdown(
        "# ⚡ Universal Model Quantizer Hub\n"
        "Convert massive diffusion and transformer models directly on the Hugging Face hub.\n"
        "Engineered with aggressive cache-clearing to prevent storage crashes on free-tier Spaces."
    )

    with gr.Row():
        with gr.Column(scale=5):
            with gr.Tabs():
                with gr.TabItem("1. Authentication & Source"):
                    hf_token = gr.Textbox(
                        label="HF Access Token (Write)", type="password", placeholder="hf_..."
                    )
                    source_repo = gr.Textbox(
                        label="Source Repository",
                        placeholder="e.g., black-forest-labs/FLUX.1-dev",
                        info="Paste any Hugging Face model repository ID.",
                    )
                    gr.Markdown("### Popular Presets")
                    with gr.Row():
                        preset_flux = gr.Button("FLUX.2-klein-9B", size="sm")
                        preset_zimage = gr.Button("Z-Image-Turbo", size="sm")
                        preset_sd3 = gr.Button("SD3.5-Large", size="sm")

                with gr.TabItem("2. Quantization Rules"):
                    arch_profile = gr.Radio(
                        choices=list(ARCH_PROFILES.keys()),
                        value="FLUX / Generic Rectified Flow",
                        label="Architecture Profile",
                        info="Crucial for INT8: Selects which layers to protect from precision loss.",
                    )
                    target_components = gr.CheckboxGroup(
                        choices=["transformer", "text_encoder", "text_encoder_2", "vae"],
                        value=["transformer"],
                        label="Folders to Quantize",
                        info="Unselected folders will be copied to the new repo unchanged.",
                    )

                with gr.TabItem("3. Output Settings"):
                    precision = gr.Dropdown(
                        choices=["FP8", "FP16", "BF16", "INT8"],
                        value="INT8",
                        label="Target Precision",
                    )
                    # Starts visible because the default precision above is INT8.
                    int8_warning = gr.Markdown(
                        "⚠️ **INT8 Selected:** Keys will be split into `weight_int8` and `weight_scale`. "
                        "Requires custom XPU/CUDA native linear classes to execute.",
                        visible=True,
                    )
                    target_repo = gr.Textbox(
                        label="Target Repository",
                        placeholder="your-username/model-name",
                        interactive=True,
                    )

            start_btn = gr.Button("🚀 Start Cloud Quantization", variant="primary", size="lg")

        with gr.Column(scale=4):
            output_log = gr.Textbox(
                label="Terminal Output", lines=24, interactive=False, max_lines=30
            )

    # Each preset fills in the source repo id and its matching architecture profile.
    preset_flux.click(
        lambda: ("black-forest-labs/FLUX.2-klein-9B", "FLUX / Generic Rectified Flow"),
        outputs=[source_repo, arch_profile],
    )
    # NOTE(review): Z-Image-Turbo is published by Tongyi-MAI; the previous value
    # was the placeholder "your-username/Z-Image-Turbo", which cannot be fetched.
    preset_zimage.click(
        lambda: ("Tongyi-MAI/Z-Image-Turbo", "Z-Image / DiT Core"),
        outputs=[source_repo, arch_profile],
    )
    preset_sd3.click(
        lambda: ("stabilityai/stable-diffusion-3.5-large", "Stable Diffusion (SDXL/SD3)"),
        outputs=[source_repo, arch_profile],
    )

    # Keep the suggested target repo in sync with the source and precision.
    source_repo.change(fn=generate_target_repo, inputs=[source_repo, precision], outputs=[target_repo])
    precision.change(fn=generate_target_repo, inputs=[source_repo, precision], outputs=[target_repo])
    precision.change(fn=toggle_int8_warning, inputs=[precision], outputs=[int8_warning])

    # convert_and_upload is a generator, so each yielded string updates the log box.
    start_btn.click(
        fn=convert_and_upload,
        inputs=[hf_token, source_repo, target_repo, precision, target_components, arch_profile],
        outputs=[output_log],
    )

if __name__ == "__main__":
    demo.launch(theme=gr.themes.Base(primary_hue="blue", neutral_hue="slate"))