# Quick-Quantize / app.py
# Author: silveroxides — uploaded using huggingface_hub (commit dbdd7bc, verified).
import os
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download, HfApi
from convert_to_quant import quantize
@spaces.GPU(duration=300)
def do_quantize(quant_args):
    """Run ``convert_to_quant.quantize`` with *quant_args* as keyword arguments.

    The ``spaces.GPU`` decorator requests a ZeroGPU allocation for the call;
    ``duration=300`` is presumably the maximum allotted time in seconds
    (per the ``spaces`` package docs — confirm). Returns ``None``.

    Args:
        quant_args: mapping of ``quantize`` keyword names to values, as built
            by ``run_quantization``.
    """
    options = dict(quant_args)
    quantize(**options)
# "Format" radio choice -> (extra quantize() kwargs, output filename suffix).
# Any choice not listed here falls back to _DEFAULT_FORMAT (per-tensor fp8).
_FORMAT_OPTIONS = {
    "int8 rowwise": ({"int8": True, "scaling_mode": "row"}, "-int8mixedrow-simple"),
    "mxfp8": ({"mxfp8": True}, "-mxfp8mixed-simple"),
}
_DEFAULT_FORMAT = ({"scaling_mode": "tensor"}, "-fp8mixed-simple")

# "Model Layer Filter" dropdown choice -> extra quantize() kwargs.
# "None" (or any unlisted choice) adds no filter flags.
_LAYER_FILTER_FLAGS = {
    "Anima": {"anima": True},
    "Microsoft Lens": {"lens": True},
    "Flux2": {"flux2": True},
    "Chroma": {"distillation_large": True},
    "Radiance": {"nerf_large": True, "radiance": True},
    "WAN": {"wan": True},
    "LTX-2.x": {"ltxv2": True},
    "Qwen Image": {"qwen": True},
    "Z-Image": {"zimage": True, "zimage_refiner": True},
}


def _clean_text(value):
    """Return *value* stripped of surrounding whitespace; ``None`` becomes ``""``."""
    return (value or "").strip()


def _build_quant_args(input_path, output_path, format_flags, layer_filter,
                      exclude_layers_regex, full_precision_matrix_mult):
    """Assemble the keyword arguments passed to ``convert_to_quant.quantize``.

    Args:
        input_path: local path of the downloaded source model file.
        output_path: local path the quantized file is written to.
        format_flags: format-specific kwargs taken from ``_FORMAT_OPTIONS``.
        layer_filter: the "Model Layer Filter" dropdown choice.
        exclude_layers_regex: optional regex of layer names to skip; falsy = none.
        full_precision_matrix_mult: checkbox value, honoured for "Qwen Image" only.

    Returns:
        A new dict of keyword arguments; the lookup tables are not mutated.
    """
    quant_args = {
        "input": input_path,
        "output": output_path,
        "comfy_quant": True,
        "save_quant_metadata": True,
        "low_memory": True,
        "simple": True,
        "calib_samples": 40960,
    }
    quant_args.update(format_flags)
    quant_args.update(_LAYER_FILTER_FLAGS.get(layer_filter, {}))
    # The full-precision matmul checkbox is only shown for the Qwen Image filter.
    if layer_filter == "Qwen Image" and full_precision_matrix_mult:
        quant_args["full_precision_matrix_mult"] = True
    if exclude_layers_regex:
        # snake_case like every other key: a hyphenated "exclude-layers" key can
        # never bind to a named Python parameter. Assumes quantize() names this
        # option exclude_layers (mirroring an --exclude-layers CLI flag) —
        # TODO confirm against convert_to_quant.
        quant_args["exclude_layers"] = exclude_layers_regex
    return quant_args


def _pr_link_html(url):
    """Return the anchor tag shown after a successful upload, with *url* HTML-escaped."""
    import html

    return (
        f"<a href='{html.escape(url, quote=True)}' target='_blank' "
        "style='color: #3b82f6; text-decoration: underline; font-weight: bold;'>"
        "Click here to view the Pull Request</a>"
    )


def _remove_quietly(path):
    """Best-effort delete of *path*; a failure only leaves disk space unreclaimed."""
    try:
        os.remove(path)
    except OSError:
        pass


def run_quantization(
    source_repo,
    source_file,
    target_repo,
    target_filename_base,
    quant_format,
    layer_filter,
    exclude_layers_regex,
    full_precision_matrix_mult,
    hf_token
):
    """Download a model file, quantize it, then upload it as a PR or offer it for download.

    Generator used as a Gradio event handler. Each yield is a
    ``(status message, gr.update(...) for the result widget)`` pair. The final
    update carries either the PR link HTML (a token was given) or the local
    output path (no token).

    Args:
        source_repo / source_file: Hub repo id and filename to download.
        target_repo: Hub repo id that receives the upload as a pull request.
        target_filename_base: output file name without format suffix or extension;
            must not contain path separators.
        quant_format: "fp8", "int8 rowwise" or "mxfp8" (anything else -> fp8).
        layer_filter: key of ``_LAYER_FILTER_FLAGS`` or "None".
        exclude_layers_regex: optional regex of layers to leave unquantized.
        full_precision_matrix_mult: Qwen Image-only option.
        hf_token: optional WRITE token; blank means "download only, no upload".

    Errors are reported as a final "Error: ..." status, not raised.
    """
    # Pasted values often carry stray whitespace, which breaks repo ids and tokens.
    source_repo, source_file, target_repo, target_filename_base = (
        _clean_text(value)
        for value in (source_repo, source_file, target_repo, target_filename_base)
    )
    hf_token = _clean_text(hf_token) or None

    if not all([source_repo, source_file, target_repo, target_filename_base]):
        yield "Please fill in all repository and filename fields.", gr.update(visible=False)
        return
    # The base name becomes a local write path; keep it inside the working directory.
    if "/" in target_filename_base or "\\" in target_filename_base:
        yield "Target filename base must not contain path separators ('/' or '\\').", gr.update(visible=False)
        return

    try:
        yield f"Downloading {source_file} from {source_repo}...", gr.update(visible=False)
        # NOTE(review): with token=None huggingface_hub may still fall back to a
        # locally configured credential (e.g. HF_TOKEN), so this is not strictly anonymous.
        local_input_path = hf_hub_download(repo_id=source_repo, filename=source_file, token=hf_token)

        format_flags, suffix = _FORMAT_OPTIONS.get(quant_format, _DEFAULT_FORMAT)
        output_filename = f"{target_filename_base}{suffix}.safetensors"
        output_path = f"./{output_filename}"
        quant_args = _build_quant_args(
            local_input_path, output_path, format_flags, layer_filter,
            exclude_layers_regex, full_precision_matrix_mult,
        )

        yield f"Quantizing to {output_filename}...\nThis may take a few minutes.", gr.update(visible=False)
        do_quantize(quant_args)

        if not hf_token:
            # No token: the local file is the deliverable, served by the File component.
            yield "Complete! Ready for download below.", gr.update(value=output_path, visible=True)
            return

        yield f"Uploading {output_filename} to {target_repo}...", gr.update(visible=False)
        commit_info = HfApi(token=hf_token).upload_file(
            path_or_fileobj=output_path,
            path_in_repo=output_filename,
            repo_id=target_repo,
            commit_message=f"Add {output_filename} quantized model",
            create_pr=True,
        )
        # The uploaded copy is no longer needed locally; free the (large) file.
        _remove_quietly(output_path)
        # pr_url may be absent or None; fall back to the repo page in either case.
        pr_url = getattr(commit_info, "pr_url", None) or f"https://huggingface.co/{target_repo}"
        yield f"Complete! Uploaded to {target_repo}", gr.update(value=_pr_link_html(pr_url), visible=True)
    except Exception as e:
        # UI boundary: surface any failure in the status log instead of crashing the handler.
        yield f"Error: {e}", gr.update(visible=False)
# Build UI
# Gradio lays components out in the order they are created inside the nested
# `with` contexts, so statement order in this block defines the page layout.
with gr.Blocks() as demo:
    with gr.Row(elem_id="topbar"):
        gr.Markdown("## 🤗 Model Quantizer", elem_classes=["brand"])
    with gr.Row(elem_id="main-row", equal_height=True):
        # Left column: every user input plus the Quantize button.
        with gr.Column(scale=4, min_width=280, elem_id="input-panel"):
            gr.Markdown("### Authentication (Optional)")
            hf_token = gr.Textbox(label="HF Token (WRITE)", type="password", placeholder="Paste your WRITE token for PR upload")
            gr.Markdown("*If no token is provided, the quantized model will be available for direct download instead of uploading as a PR.*", elem_classes=["text-sm"])
            gr.Markdown("### Input Model")
            source_repo = gr.Textbox(label="Source HF Repo (e.g. author/model)")
            source_file = gr.Textbox(label="Source Filename (e.g. model.safetensors)")
            gr.Markdown("### Output Target")
            target_repo = gr.Textbox(label="Target HF Repo (e.g. author/model)")
            target_file_base = gr.Textbox(label="Target Filename Base (e.g. model-quant)")
            gr.Markdown("### Quantization Options")
            quant_format = gr.Radio(["fp8", "int8 rowwise", "mxfp8"], value="fp8", label="Format")
            layer_filter = gr.Dropdown(
                ["None", "Anima", "Microsoft Lens", "Flux2", "Chroma", "Radiance", "WAN", "LTX-2.x", "Qwen Image", "Z-Image"],
                value="None", label="Model Layer Filter"
            )
            # Hidden until "Qwen Image" is selected (see update_qwen_visibility).
            full_precision = gr.Checkbox(label="Qwen Image: Full precision matrix mult", visible=False)

            def update_qwen_visibility(selection):
                """Return an update that shows the checkbox only for the "Qwen Image" filter."""
                return gr.update(visible=selection == "Qwen Image")

            layer_filter.change(fn=update_qwen_visibility, inputs=layer_filter, outputs=full_precision)
            exclude_layers = gr.Textbox(label="Exclude Layers Regex (Optional)", placeholder="(substring_1|substring_2)")
            run_btn = gr.Button("Quantize Model", variant="primary", size="lg")
            gr.Markdown("ℹ️ *For more advanced quantization modes, install and use [convert-to-quant](https://pypi.org/project/convert-to-quant/) locally.*", elem_classes=["text-sm", "mt-4"])
        # Right column: status log plus two result widgets, at most one of which
        # is made visible per run (PR link when uploading, file when not).
        with gr.Column(scale=8, elem_id="output-panel"):
            status_text = gr.Textbox(label="Status Log", lines=10, interactive=False)
            output_link = gr.HTML(visible=False)
            output_file = gr.File(label="Download Quantized Model", visible=False)

    def route_output(*args):
        """Fan run_quantization's (status, update) pairs out to three output components.

        Yields (status_text, output_link, output_file) updates. An update whose
        "value" is a string starting with "<a" goes to output_link; any other
        update carrying a "value" goes to output_file; updates without a "value"
        hide both result widgets.
        """
        # run_quantization yields a single result update per step; which widget it
        # targets is inferred from its value rather than passed explicitly.
        for status, result in run_quantization(*args):
            # Relies on gr.update(...) returning a dict-like update object.
            if result and isinstance(result, dict) and "value" in result:
                val = result["value"]
                # HTML anchor => PR link; otherwise treat the value as a file path.
                if isinstance(val, str) and val.startswith("<a"):
                    yield status, result, gr.update(visible=False)
                else:
                    yield status, gr.update(visible=False), result
            else:
                yield status, gr.update(visible=False), gr.update(visible=False)

    # Input order must match run_quantization's positional parameters.
    run_btn.click(
        fn=route_output,
        inputs=[
            source_repo, source_file, target_repo, target_file_base,
            quant_format, layer_filter, exclude_layers, full_precision,
            hf_token
        ],
        outputs=[status_text, output_link, output_file]
    )
if __name__ == "__main__":
    # Blue/zinc default theme plus the responsive stylesheet shipped under assets/.
    app_theme = gr.themes.Default(primary_hue="blue", neutral_hue="zinc")
    demo.launch(
        theme=app_theme,
        css_paths=["assets/responsive.css"],
    )