Spaces:
Running
Running
| import os | |
| import gc | |
| import torch | |
| import shutil | |
| import uuid | |
| import gradio as gr | |
| from huggingface_hub import HfApi, hf_hub_download | |
| from safetensors.torch import load_file, save_file | |
# Architecture profile name -> substrings of tensor keys that must NOT be
# INT8-quantized. A key containing any listed substring is kept as a plain
# (bfloat16) tensor by the INT8 path instead of being split into
# weight_int8 / weight_scale.
ARCH_PROFILES = {
    "FLUX / Generic Rectified Flow": [
        "norm",
        "ln_",
        "embed",
        "time_in",
        "vector_in",
        "guidance_in",
        "txt_in",
        "img_in",
    ],
    "Z-Image / DiT Core": [
        "t_embedder",
        "cap_embedder",
        "all_x_embedder",
        "all_final_layer",
        "rope_embedder",
        "embed_tokens",
        "norm",
        "ln_",
        "shared",
    ],
    "Stable Diffusion (SDXL/SD3)": [
        "time_embed",
        "label_emb",
        "norm",
        "ln_",
        "out.",
    ],
}
def _quantize_int8_rowwise(weight):
    """Quantize a 2-D weight to symmetric per-row INT8.

    Returns ``(int8_weight, scale)`` where ``scale`` is a bfloat16 tensor with
    one entry per row (``keepdim=True``), i.e. the value that is stored next to
    the INT8 weight.
    """
    # float32 working copy: float8 tensors do not support abs()/max(), and
    # dividing in bf16/fp16 can round quantized values off by one.
    w = weight.to(torch.float32)
    scale = (w.abs().amax(dim=1, keepdim=True) / 127.0).clamp(min=1e-8)
    # Round the scale to its stored dtype FIRST, then quantize against that
    # exact value, so dequantizing with the saved scale matches the weights.
    scale = scale.to(torch.bfloat16)
    quantized = torch.round(w / scale.to(torch.float32)).clamp(-127, 127).to(torch.int8)
    return quantized, scale


def _convert_tensors(tensors, is_int8, target_dtype, exclude_prefixes):
    """Convert a state dict to the requested precision.

    ``tensors`` is consumed (emptied) while converting so each source tensor
    can be freed as soon as its converted counterpart exists.
    """
    # In INT8 mode every floating-point tensor that is not quantized is stored
    # as bfloat16; otherwise floats are cast to the requested dtype.
    float_dtype = torch.bfloat16 if is_int8 else target_dtype
    converted = {}
    for key in list(tensors):
        tensor = tensors.pop(key)
        quantize = (
            is_int8
            # Only keys whose last component is exactly "weight"; a substring
            # test would also hit e.g. "x.weight_g" and emit colliding keys.
            and key.rsplit(".", 1)[-1] == "weight"
            and tensor.dim() == 2
            and tensor.numel() > 0
            and not any(marker in key for marker in exclude_prefixes)
        )
        if quantize:
            prefix = key[: -len("weight")]  # "a.b.weight" -> "a.b."
            q_weight, scale = _quantize_int8_rowwise(tensor)
            converted[f"{prefix}weight_int8"] = q_weight
            converted[f"{prefix}weight_scale"] = scale
        elif tensor.is_floating_point():
            converted[key] = tensor.to(float_dtype)
        else:
            converted[key] = tensor
    return converted


def _quantize_file(src_path, dst_path, is_int8, target_dtype, exclude_prefixes):
    """Load a safetensors file, convert it and write the result to ``dst_path``."""
    tensors = load_file(src_path)
    converted = _convert_tensors(tensors, is_int8, target_dtype, exclude_prefixes)
    del tensors
    save_file(converted, dst_path)


def convert_and_upload(token, source_repo, target_repo, precision, target_components, arch_profile):
    """Copy a Hub repository file by file, quantizing the selected components.

    This is a generator: it yields human-readable status strings while it
    works.

    Args:
        token: Hugging Face access token used for listing, downloading and
            uploading.
        source_repo: Repository ID to read files from.
        target_repo: Repository ID (``username/repo-name``) to upload to; it is
            created if it does not exist.
        precision: One of ``"FP8"``, ``"FP16"``, ``"BF16"`` or ``"INT8"``.
        target_components: Folder names whose ``.safetensors`` files are
            converted; every other file is uploaded unchanged.
        arch_profile: Key into ``ARCH_PROFILES`` selecting which tensor keys are
            protected from INT8 quantization (unknown keys protect nothing).

    Yields:
        str: progress, error and summary messages. Validation problems yield a
        single error message and stop; per-file failures are reported, counted
        and skipped.
    """
    if not token:
        yield "❌ Error: Please provide a valid Hugging Face Write Token."
        return
    target_repo = (target_repo or "").strip()
    if not target_repo or "/" not in target_repo:
        yield "❌ Error: Target Repository must be in format 'username/repo-name'."
        return
    if not target_components:
        yield "❌ Error: Please select at least one component to quantize."
        return
    source_repo = (source_repo or "").strip()
    if not source_repo:
        yield "❌ Error: Please provide a Source Repository ID."
        return
    if source_repo == target_repo:
        # Uploading into the source would overwrite the original weights.
        yield "❌ Error: Source and Target repositories must be different."
        return

    # Map precision. INT8 has no single target dtype; it is handled per tensor.
    is_int8 = precision == "INT8"
    target_dtype = None
    if not is_int8:
        dtype_names = {"FP8": "float8_e4m3fn", "FP16": "float16", "BF16": "bfloat16"}
        if precision not in dtype_names:
            yield f"❌ Error: Unsupported precision '{precision}'."
            return
        target_dtype = getattr(torch, dtype_names[precision], None)
        if target_dtype is None:
            yield f"❌ Error: {precision} is not supported by the installed PyTorch version."
            return

    api = HfApi(token=token)
    yield f"🔄 Verifying target repo: {target_repo}..."
    try:
        api.create_repo(repo_id=target_repo, exist_ok=True, private=False)
    except Exception as e:
        yield f"❌ Error creating repo: {e}"
        return

    yield f"📋 Fetching files from {source_repo}..."
    try:
        files = api.list_repo_files(source_repo)
    except Exception as e:
        yield f"❌ Error fetching files: {e}"
        return

    # Unique per-run directory so concurrent runs never share downloads or
    # intermediate files.
    cache_dir = f"./hf_cache_{uuid.uuid4().hex[:8]}"
    success_count, error_count = 0, 0
    exclude_prefixes = ARCH_PROFILES.get(arch_profile, [])

    for file in files:
        if "/" not in file and file.endswith(".safetensors"):
            yield f"🗑️ Auto-skipping massive root model: {file}..."
            continue
        yield f"⏳ Processing {file}..."
        try:
            os.makedirs(cache_dir, exist_ok=True)
            local_path = hf_hub_download(repo_id=source_repo, filename=file, cache_dir=cache_dir, token=token)
            # Match whole path components so "my_transformer/" is not treated
            # as the "transformer" folder.
            in_target_component = any(f"/{comp}/" in f"/{file}" for comp in target_components)
            if file.endswith(".safetensors") and in_target_component:
                yield f"🧠 Quantizing {file} to {precision}..."
                # Written inside the per-run cache dir: no name clash between
                # runs, and it is removed by the cleanup below even on failure.
                converted_path = os.path.join(cache_dir, "converted.safetensors")
                _quantize_file(local_path, converted_path, is_int8, target_dtype, exclude_prefixes)
                gc.collect()
                yield f"☁️ Uploading {precision} version of {file}..."
                api.upload_file(path_or_fileobj=converted_path, path_in_repo=file, repo_id=target_repo)
            else:
                yield f"☁️ Copying {file} as-is..."
                api.upload_file(path_or_fileobj=local_path, path_in_repo=file, repo_id=target_repo)
            success_count += 1
        except Exception as e:
            # Best-effort by design: report the failure and move on.
            error_count += 1
            yield f"⚠️ Error processing {file}: {e}\nSkipping..."
        finally:
            # Runs on success, on failure, and when the generator is closed
            # while suspended at a yield above, so disk usage stays bounded to
            # one file at a time.
            shutil.rmtree(cache_dir, ignore_errors=True)
            gc.collect()

    yield f"✅ Finished! Processed: {success_count} | Errors: {error_count}."
| # --- UI LOGIC --- | |
def generate_target_repo(source, precision):
    """Suggest a target repository ID for the given source and precision.

    The suggestion is ``your-username/<model-name>-<precision>``, where the
    model name is the last path segment of ``source``. Surrounding whitespace
    and trailing slashes are ignored.

    Returns:
        The suggested repository ID, or an empty string when ``source`` holds
        no model name (so no bogus ``your-username/-INT8`` is suggested).
    """
    cleaned = (source or "").strip().rstrip("/")
    model_name = cleaned.rsplit("/", 1)[-1].strip()
    if not model_name:
        return ""
    return f"your-username/{model_name}-{precision}"
def toggle_int8_warning(precision):
    """Return a Gradio update that shows the INT8 notice only for INT8."""
    int8_selected = precision == "INT8"
    return gr.update(visible=int8_selected)
# --- GUI ---
# gr.Blocks() is created without a theme argument; the theme is supplied to
# demo.launch() instead.
def _preset_values(repo_id, profile_name):
    """Build a no-argument click handler returning a preset (repo, profile) pair."""
    return lambda: (repo_id, profile_name)


with gr.Blocks() as demo:
    gr.Markdown(
        "\n"
        "# ⚡ Universal Model Quantizer Hub\n"
        "Convert massive diffusion and transformer models directly on the Hugging Face hub.\n"
        "Engineered with aggressive cache-clearing to prevent storage crashes on free-tier Spaces.\n"
    )

    with gr.Row():
        # Left column: the three configuration tabs plus the start button.
        with gr.Column(scale=5):
            with gr.Tabs():
                with gr.TabItem("1. Authentication & Source"):
                    hf_token = gr.Textbox(
                        label="HF Access Token (Write)",
                        type="password",
                        placeholder="hf_...",
                    )
                    source_repo = gr.Textbox(
                        label="Source Repository",
                        placeholder="e.g., black-forest-labs/FLUX.1-dev",
                        info="Paste any Hugging Face model repository ID.",
                    )
                    gr.Markdown("### Popular Presets")
                    with gr.Row():
                        preset_flux = gr.Button("FLUX.2-klein-9B", size="sm")
                        preset_zimage = gr.Button("Z-Image-Turbo", size="sm")
                        preset_sd3 = gr.Button("SD3.5-Large", size="sm")

                with gr.TabItem("2. Quantization Rules"):
                    arch_profile = gr.Radio(
                        choices=list(ARCH_PROFILES.keys()),
                        value="FLUX / Generic Rectified Flow",
                        label="Architecture Profile",
                        info="Crucial for INT8: Selects which layers to protect from precision loss.",
                    )
                    target_components = gr.CheckboxGroup(
                        choices=["transformer", "text_encoder", "text_encoder_2", "vae"],
                        value=["transformer"],
                        label="Folders to Quantize",
                        info="Unselected folders will be copied to the new repo unchanged.",
                    )

                with gr.TabItem("3. Output Settings"):
                    precision = gr.Dropdown(
                        choices=["FP8", "FP16", "BF16", "INT8"],
                        value="INT8",
                        label="Target Precision",
                    )
                    # Visible from the start because the default precision is INT8.
                    int8_warning = gr.Markdown(
                        "⚠️ **INT8 Selected:** Keys will be split into `weight_int8` and `weight_scale`. "
                        "Requires custom XPU/CUDA native linear classes to execute.",
                        visible=True,
                    )
                    target_repo = gr.Textbox(
                        label="Target Repository",
                        placeholder="your-username/model-name",
                        interactive=True,
                    )

            # Below the tabs, so it is reachable whichever tab is open.
            start_btn = gr.Button("🚀 Start Cloud Quantization", variant="primary", size="lg")

        # Right column: streamed status messages from the conversion run.
        with gr.Column(scale=4):
            output_log = gr.Textbox(
                label="Terminal Output",
                lines=24,
                interactive=False,
                max_lines=30,
            )

    # Preset buttons fill in the source repository and matching profile.
    preset_bindings = (
        (preset_flux, "black-forest-labs/FLUX.2-klein-9B", "FLUX / Generic Rectified Flow"),
        (preset_zimage, "your-username/Z-Image-Turbo", "Z-Image / DiT Core"),
        (preset_sd3, "stabilityai/stable-diffusion-3.5-large", "Stable Diffusion (SDXL/SD3)"),
    )
    for preset_button, preset_repo, preset_profile in preset_bindings:
        preset_button.click(
            _preset_values(preset_repo, preset_profile),
            outputs=[source_repo, arch_profile],
        )

    # Re-suggest the target repository whenever the source or precision changes.
    for trigger in (source_repo, precision):
        trigger.change(
            fn=generate_target_repo,
            inputs=[source_repo, precision],
            outputs=[target_repo],
        )
    precision.change(fn=toggle_int8_warning, inputs=[precision], outputs=[int8_warning])

    start_btn.click(
        fn=convert_and_upload,
        inputs=[hf_token, source_repo, target_repo, precision, target_components, arch_profile],
        outputs=[output_log],
    )
if __name__ == "__main__":
    # The theme is applied at launch time rather than on gr.Blocks().
    app_theme = gr.themes.Base(primary_hue="blue", neutral_hue="slate")
    demo.launch(theme=app_theme)