# Hugging Face Spaces — Running on Zero — File size: 7,783 Bytes
import os
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download, HfApi
from convert_to_quant import quantize
@spaces.GPU(duration=300)
def do_quantize(quant_args):
    """Run ``quantize`` with the given keyword arguments.

    Args:
        quant_args: Mapping whose entries are unpacked as keyword arguments
            to ``convert_to_quant.quantize``.

    Returns:
        None. Whatever ``quantize`` returns is discarded.
    """
    # NOTE(review): the ``spaces.GPU(duration=300)`` decorator presumably
    # requests a GPU slot for up to 300 seconds per call -- confirm against
    # the ``spaces`` package documentation.
    quantize(**quant_args)
# Quantizer flags and output-filename suffix for each "Format" choice.
_QUANT_FORMATS = {
    "int8 rowwise": ({"int8": True, "scaling_mode": "row"}, "-int8mixedrow-simple"),
    "mxfp8": ({"mxfp8": True}, "-mxfp8mixed-simple"),
}
# Any other value (including "fp8") falls back to this entry.
_DEFAULT_QUANT_FORMAT = ({"scaling_mode": "tensor"}, "-fp8mixed-simple")

# Quantizer flags switched on by each "Model Layer Filter" choice.
# Choices that are absent here (e.g. "None") enable no flags.
_LAYER_FILTER_FLAGS = {
    "Anima": ("anima",),
    "Microsoft Lens": ("lens",),
    "Flux2": ("flux2",),
    "Chroma": ("distillation_large",),
    "Radiance": ("nerf_large", "radiance"),
    "WAN": ("wan",),
    "LTX-2.x": ("ltxv2",),
    "Qwen Image": ("qwen",),
    "Z-Image": ("zimage", "zimage_refiner"),
}


def _strip_text(value):
    """Return ``value`` without surrounding whitespace when it is a string."""
    return value.strip() if isinstance(value, str) else value


def run_quantization(
    source_repo,
    source_file,
    target_repo,
    target_filename_base,
    quant_format,
    layer_filter,
    exclude_layers_regex,
    full_precision_matrix_mult,
    hf_token
):
    """Download a model file, quantize it, then upload it as a PR or offer it for download.

    This is a generator: every step yields a ``(status, update)`` pair where
    ``status`` is a human-readable message and ``update`` is a ``gr.update``
    payload. The final payload carries either an ``<a ...>`` HTML link to the
    pull request (token given) or the local path of the quantized file
    (no token given); all intermediate payloads are ``visible=False``.

    Args:
        source_repo: Hub repo id to download from.
        source_file: Filename inside ``source_repo``.
        target_repo: Hub repo id that receives the pull request.
        target_filename_base: Output filename without suffix/extension.
        quant_format: ``"int8 rowwise"``, ``"mxfp8"``; anything else is
            treated as fp8.
        layer_filter: One of the model layer filter labels; unknown labels
            enable no extra flags.
        exclude_layers_regex: Optional pattern forwarded to the quantizer.
        full_precision_matrix_mult: Only honoured when ``layer_filter`` is
            ``"Qwen Image"``.
        hf_token: Optional token. Used for the download and, when present,
            for uploading the result as a pull request.

    Yields:
        Tuples of ``(status_message, gr.update(...))``. Any exception raised
        while working is reported as an ``"Error: ..."`` status instead of
        propagating.
    """
    import html  # local import: only needed to escape the PR link markup

    # Pasted tokens and repo ids often carry stray spaces or newlines.
    source_repo = _strip_text(source_repo)
    source_file = _strip_text(source_file)
    target_repo = _strip_text(target_repo)
    target_filename_base = _strip_text(target_filename_base)
    hf_token = _strip_text(hf_token)

    if not all([source_repo, source_file, target_repo, target_filename_base]):
        yield "Please fill in all repository and filename fields.", gr.update(visible=False)
        return

    try:
        # Download
        yield f"Downloading {source_file} from {source_repo}...", gr.update(visible=False)
        # We use the token for download if provided, otherwise anonymous
        local_input_path = hf_hub_download(
            repo_id=source_repo,
            filename=source_file,
            token=hf_token or None,
        )

        # Setup quant arguments based on UI
        quant_args = {
            "input": local_input_path,
            "comfy_quant": True,
            "save_quant_metadata": True,
            "low_memory": True,
            "simple": True,
            "calib_samples": 40960,
        }

        format_flags, suffix = _QUANT_FORMATS.get(quant_format, _DEFAULT_QUANT_FORMAT)
        quant_args.update(format_flags)

        output_filename = f"{target_filename_base}{suffix}.safetensors"
        # The filename base is user input: keep only its final component for
        # the local path so it cannot escape the working directory or point
        # into a directory that does not exist. The repo path keeps the
        # name exactly as typed.
        output_path = f"./{os.path.basename(output_filename)}"
        quant_args["output"] = output_path

        # Layer filters
        for flag in _LAYER_FILTER_FLAGS.get(layer_filter, ()):
            quant_args[flag] = True
        if layer_filter == "Qwen Image" and full_precision_matrix_mult:
            quant_args["full_precision_matrix_mult"] = True

        if exclude_layers_regex:
            # NOTE(review): this key contains a hyphen, so it can only be
            # received through a ``**kwargs`` catch-all in ``quantize``;
            # confirm it should not be "exclude_layers".
            quant_args["exclude-layers"] = exclude_layers_regex

        yield f"Quantizing to {output_filename}...\nThis may take a few minutes.", gr.update(visible=False)
        do_quantize(quant_args)

        if hf_token:
            yield f"Uploading {output_filename} to {target_repo}...", gr.update(visible=False)
            # Upload
            try:
                api = HfApi(token=hf_token)
                commit_info = api.upload_file(
                    path_or_fileobj=output_path,
                    path_in_repo=output_filename,
                    repo_id=target_repo,
                    commit_message=f"Add {output_filename} quantized model",
                    create_pr=True,
                )
            finally:
                # In upload mode the local copy is never served to the user,
                # so drop it whether or not the upload succeeded. Cleanup is
                # best-effort and must not mask the real outcome.
                try:
                    os.remove(output_path)
                except OSError:
                    pass

            # ``pr_url`` may be missing or None; fall back to the repo page
            # rather than rendering a link to "None".
            pr_url = getattr(commit_info, "pr_url", None) or f"https://huggingface.co/{target_repo}"
            safe_url = html.escape(str(pr_url), quote=True)
            link_markup = (
                f"<a href='{safe_url}' target='_blank' "
                "style='color: #3b82f6; text-decoration: underline; font-weight: bold;'>"
                "Click here to view the Pull Request</a>"
            )
            yield f"Complete! Uploaded to {target_repo}", gr.update(value=link_markup, visible=True)
        else:
            yield "Complete! Ready for download below.", gr.update(value=output_path, visible=True)
    except Exception as e:
        # Top-level UI boundary: surface the failure in the status box.
        yield f"Error: {e}", gr.update(visible=False)
# Build UI
_QUANT_FORMAT_CHOICES = ["fp8", "int8 rowwise", "mxfp8"]
_LAYER_FILTER_CHOICES = [
    "None",
    "Anima",
    "Microsoft Lens",
    "Flux2",
    "Chroma",
    "Radiance",
    "WAN",
    "LTX-2.x",
    "Qwen Image",
    "Z-Image",
]


def update_qwen_visibility(selection):
    """Show the full-precision checkbox only for the "Qwen Image" filter."""
    is_qwen = selection == "Qwen Image"
    return gr.update(visible=is_qwen)


def route_output(*args):
    """Fan ``run_quantization`` results out to the status, link and file outputs.

    Each yielded triple is ``(status, link_update, file_update)``. A result
    whose value is a string starting with ``<a`` goes to the HTML link
    output; any other result carrying a value goes to the file output;
    results without a value hide both.
    """
    for status, result in run_quantization(*args):
        has_value = isinstance(result, dict) and "value" in result
        if not has_value:
            yield status, gr.update(visible=False), gr.update(visible=False)
            continue

        payload = result["value"]
        if isinstance(payload, str) and payload.startswith("<a"):
            # HTML anchor -> pull request link
            yield status, result, gr.update(visible=False)
        else:
            # Anything else is treated as a file path for download
            yield status, gr.update(visible=False), result


with gr.Blocks() as demo:
    with gr.Row(elem_id="topbar"):
        gr.Markdown("## 🤗 Model Quantizer", elem_classes=["brand"])

    with gr.Row(elem_id="main-row", equal_height=True):
        # Left column: all inputs
        with gr.Column(scale=4, min_width=280, elem_id="input-panel"):
            gr.Markdown("### Authentication (Optional)")
            hf_token = gr.Textbox(
                label="HF Token (WRITE)",
                type="password",
                placeholder="Paste your WRITE token for PR upload",
            )
            gr.Markdown(
                "*If no token is provided, the quantized model will be available for direct download instead of uploading as a PR.*",
                elem_classes=["text-sm"],
            )

            gr.Markdown("### Input Model")
            source_repo = gr.Textbox(label="Source HF Repo (e.g. author/model)")
            source_file = gr.Textbox(label="Source Filename (e.g. model.safetensors)")

            gr.Markdown("### Output Target")
            target_repo = gr.Textbox(label="Target HF Repo (e.g. author/model)")
            target_file_base = gr.Textbox(label="Target Filename Base (e.g. model-quant)")

            gr.Markdown("### Quantization Options")
            quant_format = gr.Radio(_QUANT_FORMAT_CHOICES, value="fp8", label="Format")
            layer_filter = gr.Dropdown(
                _LAYER_FILTER_CHOICES,
                value="None",
                label="Model Layer Filter",
            )
            full_precision = gr.Checkbox(
                label="Qwen Image: Full precision matrix mult",
                visible=False,
            )
            layer_filter.change(
                fn=update_qwen_visibility,
                inputs=layer_filter,
                outputs=full_precision,
            )
            exclude_layers = gr.Textbox(
                label="Exclude Layers Regex (Optional)",
                placeholder="(substring_1|substring_2)",
            )
            run_btn = gr.Button("Quantize Model", variant="primary", size="lg")
            gr.Markdown(
                "ℹ️ *For more advanced quantization modes, install and use [convert-to-quant](https://pypi.org/project/convert-to-quant/) locally.*",
                elem_classes=["text-sm", "mt-4"],
            )

        # Right column: status log plus the two mutually exclusive results
        with gr.Column(scale=8, elem_id="output-panel"):
            status_text = gr.Textbox(label="Status Log", lines=10, interactive=False)
            output_link = gr.HTML(visible=False)
            output_file = gr.File(label="Download Quantized Model", visible=False)

    # Input order must match the positional parameters of run_quantization.
    quantize_inputs = [
        source_repo,
        source_file,
        target_repo,
        target_file_base,
        quant_format,
        layer_filter,
        exclude_layers,
        full_precision,
        hf_token,
    ]
    run_btn.click(
        fn=route_output,
        inputs=quantize_inputs,
        outputs=[status_text, output_link, output_file],
    )
if __name__ == "__main__":
    # Start the app only when executed as a script, with the default theme
    # recoloured (blue primary, zinc neutral) and the responsive stylesheet.
    app_theme = gr.themes.Default(primary_hue="blue", neutral_hue="zinc")
    demo.launch(
        css_paths=["assets/responsive.css"],
        theme=app_theme,
    )
|