# Spaces: Running on Zero
import html
import os

import gradio as gr
import spaces
from huggingface_hub import hf_hub_download, HfApi
from convert_to_quant import quantize
def do_quantize(quant_args):
    """Run ``quantize`` with the entries of *quant_args* as keyword arguments.

    Args:
        quant_args: Mapping of option name to value; each entry is forwarded
            to ``quantize`` as ``name=value``.

    Returns:
        None. Whatever ``quantize`` returns is discarded.
    """
    # NOTE(review): the file imports `spaces`, but nothing in the visible
    # source uses it (e.g. as a decorator on this wrapper) - confirm whether
    # the quantization is meant to run without a GPU allocation.
    options = dict(quant_args)
    quantize(**options)
# Keyword flags passed to ``quantize`` for each "Model Layer Filter" choice.
# Choices without an entry here (e.g. "None") enable no extra flag.
_LAYER_FILTER_FLAGS = {
    "Anima": ("anima",),
    "Microsoft Lens": ("lens",),
    "Flux2": ("flux2",),
    "Chroma": ("distillation_large",),
    "Radiance": ("nerf_large", "radiance"),
    "WAN": ("wan",),
    "LTX-2.x": ("ltxv2",),
    "Qwen Image": ("qwen",),
    "Z-Image": ("zimage", "zimage_refiner"),
}


def _build_quant_args(
    local_input_path,
    target_filename_base,
    quant_format,
    layer_filter,
    exclude_layers_regex,
    full_precision_matrix_mult,
):
    """Translate the form selections into ``quantize`` keyword arguments.

    Returns:
        A ``(quant_args, output_filename)`` tuple. ``output_filename`` is the
        name used inside the target repo; ``quant_args["output"]`` is the
        local path the quantized file is written to.
    """
    quant_args = {
        "input": local_input_path,
        "comfy_quant": True,
        "save_quant_metadata": True,
        "low_memory": True,
        "simple": True,
        "calib_samples": 40960,
    }

    if quant_format == "int8 rowwise":
        quant_args["int8"] = True
        quant_args["scaling_mode"] = "row"
        suffix = "-int8mixedrow-simple"
    elif quant_format == "mxfp8":
        quant_args["mxfp8"] = True
        suffix = "-mxfp8mixed-simple"
    else:  # fp8 (default, also used for any unrecognized format)
        quant_args["scaling_mode"] = "tensor"
        suffix = "-fp8mixed-simple"

    output_filename = f"{target_filename_base}{suffix}.safetensors"
    # The filename base is user input: keep only its last path component for
    # the local file so "dir/name" or "../name" cannot write outside the
    # working directory (or into a folder that does not exist).
    quant_args["output"] = f"./{os.path.basename(output_filename)}"

    for flag in _LAYER_FILTER_FLAGS.get(layer_filter, ()):
        quant_args[flag] = True
    # The full-precision option only applies to the Qwen Image filter.
    if layer_filter == "Qwen Image" and full_precision_matrix_mult:
        quant_args["full_precision_matrix_mult"] = True

    if exclude_layers_regex:
        # Every other option is spelled in snake_case and the dict is expanded
        # with ``quantize(**quant_args)``; a hyphenated key cannot match a
        # named Python parameter, so use the underscore spelling here too.
        quant_args["exclude_layers"] = exclude_layers_regex

    return quant_args, output_filename


def run_quantization(
    source_repo,
    source_file,
    target_repo,
    target_filename_base,
    quant_format,
    layer_filter,
    exclude_layers_regex,
    full_precision_matrix_mult,
    hf_token,
):
    """Download a model file, quantize it, then upload it or offer it for download.

    This is a generator: each step yields a ``(status_message, update)`` pair,
    where ``update`` is a ``gr.update(...)`` dict. Intermediate and error
    steps yield ``gr.update(visible=False)``; the final step carries either an
    HTML link to the pull request (when ``hf_token`` is given) or the local
    path of the quantized file (when it is not).

    Args:
        source_repo: Repo id to download from.
        source_file: Filename inside ``source_repo``.
        target_repo: Repo id the pull request is opened against.
        target_filename_base: Output filename without the format suffix.
        quant_format: ``"int8 rowwise"``, ``"mxfp8"``; anything else is
            treated as fp8.
        layer_filter: Model layer filter choice from the UI.
        exclude_layers_regex: Optional regex forwarded to the quantizer.
        full_precision_matrix_mult: Only honoured for the "Qwen Image" filter.
        hf_token: Optional token, used for both the download and the upload.

    Yields:
        ``(str, dict)`` tuples as described above. Exceptions raised while
        downloading, quantizing or uploading are reported as an
        ``"Error: ..."`` status instead of propagating.
    """
    if not all([source_repo, source_file, target_repo, target_filename_base]):
        yield "Please fill in all repository and filename fields.", gr.update(visible=False)
        return

    try:
        yield f"Downloading {source_file} from {source_repo}...", gr.update(visible=False)
        # Use the token for the download if provided, otherwise go anonymous.
        local_input_path = hf_hub_download(
            repo_id=source_repo,
            filename=source_file,
            token=hf_token or None,
        )

        quant_args, output_filename = _build_quant_args(
            local_input_path,
            target_filename_base,
            quant_format,
            layer_filter,
            exclude_layers_regex,
            full_precision_matrix_mult,
        )
        output_path = quant_args["output"]

        yield f"Quantizing to {output_filename}...\nThis may take a few minutes.", gr.update(visible=False)
        do_quantize(quant_args)

        if not hf_token:
            # No token: hand the local file to the download component.
            yield "Complete! Ready for download below.", gr.update(value=output_path, visible=True)
            return

        yield f"Uploading {output_filename} to {target_repo}...", gr.update(visible=False)
        api = HfApi(token=hf_token)
        try:
            commit_info = api.upload_file(
                path_or_fileobj=output_path,
                path_in_repo=output_filename,
                repo_id=target_repo,
                commit_message=f"Add {output_filename} quantized model",
                create_pr=True,
            )
        finally:
            # The local copy is never served on this path; remove it whether
            # or not the upload succeeded so files do not pile up on disk.
            try:
                os.remove(output_path)
            except OSError:
                pass  # best-effort cleanup; must not mask the upload result

        # ``pr_url`` may be missing or empty; fall back to the repo page
        # rather than rendering a link to the literal string "None".
        pr_url = getattr(commit_info, "pr_url", None) or f"https://huggingface.co/{target_repo}"
        # The fallback URL embeds user input, so escape it for the attribute.
        safe_url = html.escape(pr_url, quote=True)
        link_html = (
            f"<a href='{safe_url}' target='_blank' "
            "style='color: #3b82f6; text-decoration: underline; font-weight: bold;'>"
            "Click here to view the Pull Request</a>"
        )
        yield f"Complete! Uploaded to {target_repo}", gr.update(value=link_html, visible=True)
    except Exception as e:
        # Top-level UI boundary: surface the failure in the status box.
        yield f"Error: {str(e)}", gr.update(visible=False)
# --- UI callbacks -----------------------------------------------------------

def update_qwen_visibility(selection):
    """Show the full-precision checkbox only while "Qwen Image" is selected."""
    is_qwen = selection == "Qwen Image"
    return gr.update(visible=is_qwen)


def route_output(*args):
    """Fan each ``run_quantization`` step out to the three output components.

    Yields ``(status, link_update, file_update)``. An update whose value is a
    string starting with ``<a`` goes to the HTML link component; any other
    update carrying a value goes to the file component; updates without a
    value hide both.
    """
    for status, payload in run_quantization(*args):
        carries_value = payload and isinstance(payload, dict) and "value" in payload
        if not carries_value:
            yield status, gr.update(visible=False), gr.update(visible=False)
            continue

        content = payload["value"]
        if isinstance(content, str) and content.startswith("<a"):
            # HTML anchor: show the link, hide the file download.
            yield status, payload, gr.update(visible=False)
        else:
            # Anything else is treated as a file path.
            yield status, gr.update(visible=False), payload


# --- UI layout --------------------------------------------------------------

with gr.Blocks() as demo:
    with gr.Row(elem_id="topbar"):
        gr.Markdown("## 🤗 Model Quantizer", elem_classes=["brand"])

    with gr.Row(elem_id="main-row", equal_height=True):
        # Left panel: all user inputs.
        with gr.Column(scale=4, min_width=280, elem_id="input-panel"):
            gr.Markdown("### Authentication (Optional)")
            hf_token = gr.Textbox(
                label="HF Token (WRITE)",
                type="password",
                placeholder="Paste your WRITE token for PR upload",
            )
            gr.Markdown(
                "*If no token is provided, the quantized model will be available for direct download instead of uploading as a PR.*",
                elem_classes=["text-sm"],
            )

            gr.Markdown("### Input Model")
            source_repo = gr.Textbox(label="Source HF Repo (e.g. author/model)")
            source_file = gr.Textbox(label="Source Filename (e.g. model.safetensors)")

            gr.Markdown("### Output Target")
            target_repo = gr.Textbox(label="Target HF Repo (e.g. author/model)")
            target_file_base = gr.Textbox(label="Target Filename Base (e.g. model-quant)")

            gr.Markdown("### Quantization Options")
            quant_format = gr.Radio(
                ["fp8", "int8 rowwise", "mxfp8"],
                value="fp8",
                label="Format",
            )
            layer_filter = gr.Dropdown(
                [
                    "None",
                    "Anima",
                    "Microsoft Lens",
                    "Flux2",
                    "Chroma",
                    "Radiance",
                    "WAN",
                    "LTX-2.x",
                    "Qwen Image",
                    "Z-Image",
                ],
                value="None",
                label="Model Layer Filter",
            )
            full_precision = gr.Checkbox(
                label="Qwen Image: Full precision matrix mult",
                visible=False,
            )
            layer_filter.change(
                fn=update_qwen_visibility,
                inputs=layer_filter,
                outputs=full_precision,
            )
            exclude_layers = gr.Textbox(
                label="Exclude Layers Regex (Optional)",
                placeholder="(substring_1|substring_2)",
            )
            run_btn = gr.Button("Quantize Model", variant="primary", size="lg")
            gr.Markdown(
                "ℹ️ *For more advanced quantization modes, install and use [convert-to-quant](https://pypi.org/project/convert-to-quant/) locally.*",
                elem_classes=["text-sm", "mt-4"],
            )

        # Right panel: status log plus whichever result widget applies.
        with gr.Column(scale=8, elem_id="output-panel"):
            status_text = gr.Textbox(label="Status Log", lines=10, interactive=False)
            output_link = gr.HTML(visible=False)
            output_file = gr.File(label="Download Quantized Model", visible=False)

    # Input order must match the positional parameters of run_quantization.
    run_btn.click(
        fn=route_output,
        inputs=[
            source_repo,
            source_file,
            target_repo,
            target_file_base,
            quant_format,
            layer_filter,
            exclude_layers,
            full_precision,
            hf_token,
        ],
        outputs=[status_text, output_link, output_file],
    )
if __name__ == "__main__":
    # Start the app only when run as a script; the stylesheet path and the
    # theme are supplied at launch time.
    demo.launch(
        theme=gr.themes.Default(primary_hue="blue", neutral_hue="zinc"),
        css_paths=["assets/responsive.css"],
    )