File size: 7,783 Bytes
cd58174
 
 
 
 
 
 
de0e007
 
 
cd58174
 
 
 
 
 
 
 
 
13e5408
cd58174
 
13e5408
cd58174
 
 
 
13e5408
 
 
cd58174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13e5408
cd58174
de0e007
cd58174
13e5408
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd58174
 
13e5408
cd58174
 
 
13e5408
cd58174
 
 
 
 
13e5408
 
 
 
cd58174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13e5408
cd58174
 
 
 
 
13e5408
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd58174
 
13e5408
cd58174
 
13e5408
 
cd58174
13e5408
cd58174
 
 
dbdd7bc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import os
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download, HfApi
from convert_to_quant import quantize

@spaces.GPU(duration=300)
def do_quantize(quant_args: dict) -> None:
    """Run ``quantize`` with the given keyword arguments.

    Args:
        quant_args: Mapping that is unpacked as keyword arguments into
            ``convert_to_quant.quantize``.

    Returns:
        None. The return value of ``quantize`` is discarded.

    NOTE(review): the ``spaces.GPU(duration=300)`` decorator presumably
    schedules this call on a ZeroGPU worker with a 300-second budget;
    confirm against the ``spaces`` package documentation.
    """
    quantize(**quant_args)

# Options sent to ``quantize`` on every run, regardless of UI choices.
_BASE_QUANT_ARGS = {
    "comfy_quant": True,
    "save_quant_metadata": True,
    "low_memory": True,
    "simple": True,
    "calib_samples": 40960,
}

# UI format label -> (extra quantize kwargs, output filename suffix).
_QUANT_FORMATS = {
    "int8 rowwise": ({"int8": True, "scaling_mode": "row"}, "-int8mixedrow-simple"),
    "mxfp8": ({"mxfp8": True}, "-mxfp8mixed-simple"),
}
# Any label not listed above (including "fp8") uses tensor-scaled fp8.
_DEFAULT_QUANT_FORMAT = ({"scaling_mode": "tensor"}, "-fp8mixed-simple")

# UI layer-filter label -> quantize flags that are switched on for it.
# Labels that are absent here (e.g. "None") enable no filter flag.
_LAYER_FILTER_FLAGS = {
    "Anima": ("anima",),
    "Microsoft Lens": ("lens",),
    "Flux2": ("flux2",),
    "Chroma": ("distillation_large",),
    "Radiance": ("nerf_large", "radiance"),
    "WAN": ("wan",),
    "LTX-2.x": ("ltxv2",),
    "Qwen Image": ("qwen",),
    "Z-Image": ("zimage", "zimage_refiner"),
}


def run_quantization(
    source_repo,
    source_file,
    target_repo,
    target_filename_base,
    quant_format,
    layer_filter,
    exclude_layers_regex,
    full_precision_matrix_mult,
    hf_token,
):
    """Download a model file, quantize it, then upload it as a PR or offer it for download.

    This is a generator. Each item is a ``(status_message, update)`` pair
    where ``update`` is a ``gr.update(...)`` result: hidden while work is in
    progress, and on success carrying either an HTML link to the pull request
    (token given) or the local path of the quantized file (no token).

    Args:
        source_repo: Hub repo id to download from.
        source_file: Filename inside ``source_repo``.
        target_repo: Hub repo id to upload to; only required when
            ``hf_token`` is provided.
        target_filename_base: Output filename without suffix/extension.
        quant_format: ``"int8 rowwise"``, ``"mxfp8"``; anything else is
            treated as fp8.
        layer_filter: Model-family label selecting layer-filter flags.
        exclude_layers_regex: Optional pattern forwarded to ``quantize``.
        full_precision_matrix_mult: Only honoured for ``"Qwen Image"``.
        hf_token: Optional token; when set, the result is uploaded as a PR.

    Yields:
        ``(str, update)`` tuples as described above. Any ``Exception`` raised
        during the run is reported as an ``"Error: ..."`` status instead of
        propagating.
    """
    # Pasted repo ids and tokens often carry stray whitespace or newlines.
    source_repo, source_file, target_repo, target_filename_base, hf_token = (
        (value or "").strip()
        for value in (source_repo, source_file, target_repo, target_filename_base, hf_token)
    )

    # The target repo is only used for the upload, so it is not required for
    # the token-less "download the result" flow.
    required_fields = [source_repo, source_file, target_filename_base]
    if hf_token:
        required_fields.append(target_repo)
    if not all(required_fields):
        yield "Please fill in all repository and filename fields.", gr.update(visible=False)
        return

    try:
        yield f"Downloading {source_file} from {source_repo}...", gr.update(visible=False)
        # Use the token for the download when provided, otherwise go anonymous.
        local_input_path = hf_hub_download(
            repo_id=source_repo,
            filename=source_file,
            token=hf_token or None,
        )

        quant_args = {"input": local_input_path, **_BASE_QUANT_ARGS}

        format_args, suffix = _QUANT_FORMATS.get(quant_format, _DEFAULT_QUANT_FORMAT)
        quant_args.update(format_args)

        output_filename = f"{target_filename_base}{suffix}.safetensors"
        output_path = f"./{output_filename}"
        quant_args["output"] = output_path

        for flag in _LAYER_FILTER_FLAGS.get(layer_filter, ()):
            quant_args[flag] = True
        if layer_filter == "Qwen Image" and full_precision_matrix_mult:
            quant_args["full_precision_matrix_mult"] = True

        if exclude_layers_regex:
            # NOTE(review): every other key uses underscores; a hyphenated key
            # cannot bind to a normal Python parameter name. Left unchanged
            # because quantize()'s signature is not visible here - confirm
            # whether this should be "exclude_layers".
            quant_args["exclude-layers"] = exclude_layers_regex

        yield f"Quantizing to {output_filename}...\nThis may take a few minutes.", gr.update(visible=False)

        do_quantize(quant_args)

        if not hf_token:
            yield "Complete! Ready for download below.", gr.update(value=output_path, visible=True)
            return

        yield f"Uploading {output_filename} to {target_repo}...", gr.update(visible=False)

        api = HfApi(token=hf_token)
        try:
            commit_info = api.upload_file(
                path_or_fileobj=output_path,
                path_in_repo=output_filename,
                repo_id=target_repo,
                commit_message=f"Add {output_filename} quantized model",
                create_pr=True,
            )
        finally:
            # The uploaded artifact is never served locally, so do not let
            # multi-GB files pile up on disk. Best effort only.
            try:
                os.remove(output_path)
            except OSError:
                pass

        # The attribute may exist but be None; fall back to the repo page so
        # the link never points at the literal string "None".
        pr_url = getattr(commit_info, "pr_url", None) or f"https://huggingface.co/{target_repo}"

        yield f"Complete! Uploaded to {target_repo}", gr.update(value=f"<a href='{pr_url}' target='_blank' style='color: #3b82f6; text-decoration: underline; font-weight: bold;'>Click here to view the Pull Request</a>", visible=True)

    except Exception as e:
        # UI boundary: surface the failure in the status box instead of crashing.
        yield f"Error: {str(e)}", gr.update(visible=False)


# Build UI
with gr.Blocks() as demo:
    # Header row with the app title.
    with gr.Row(elem_id="topbar"):
        gr.Markdown("## 🤗 Model Quantizer", elem_classes=["brand"])

    # Two-column body: inputs on the left (scale 4), results on the right (scale 8).
    with gr.Row(elem_id="main-row", equal_height=True):
        with gr.Column(scale=4, min_width=280, elem_id="input-panel"):
            gr.Markdown("### Authentication (Optional)")
            hf_token = gr.Textbox(label="HF Token (WRITE)", type="password", placeholder="Paste your WRITE token for PR upload")
            gr.Markdown("*If no token is provided, the quantized model will be available for direct download instead of uploading as a PR.*", elem_classes=["text-sm"])
            
            gr.Markdown("### Input Model")
            source_repo = gr.Textbox(label="Source HF Repo (e.g. author/model)")
            source_file = gr.Textbox(label="Source Filename (e.g. model.safetensors)")

            gr.Markdown("### Output Target")
            target_repo = gr.Textbox(label="Target HF Repo (e.g. author/model)")
            target_file_base = gr.Textbox(label="Target Filename Base (e.g. model-quant)")

            gr.Markdown("### Quantization Options")
            # These labels are compared as literal strings in run_quantization.
            quant_format = gr.Radio(["fp8", "int8 rowwise", "mxfp8"], value="fp8", label="Format")

            # These labels are also matched literally in run_quantization;
            # "None" matches no branch there, so no filter flag is set.
            layer_filter = gr.Dropdown(
                ["None", "Anima", "Microsoft Lens", "Flux2", "Chroma", "Radiance", "WAN", "LTX-2.x", "Qwen Image", "Z-Image"],
                value="None", label="Model Layer Filter"
            )
            # Hidden until "Qwen Image" is selected (see the change handler below).
            full_precision = gr.Checkbox(label="Qwen Image: Full precision matrix mult", visible=False)

            def update_qwen_visibility(selection):
                """Return an update that shows the checkbox only for "Qwen Image"."""
                return gr.update(visible=selection == "Qwen Image")
            layer_filter.change(fn=update_qwen_visibility, inputs=layer_filter, outputs=full_precision)

            exclude_layers = gr.Textbox(label="Exclude Layers Regex (Optional)", placeholder="(substring_1|substring_2)")

            run_btn = gr.Button("Quantize Model", variant="primary", size="lg")

            gr.Markdown("ℹ️ *For more advanced quantization modes, install and use [convert-to-quant](https://pypi.org/project/convert-to-quant/) locally.*", elem_classes=["text-sm", "mt-4"])

        with gr.Column(scale=8, elem_id="output-panel"):
            status_text = gr.Textbox(label="Status Log", lines=10, interactive=False)
            # Both result widgets start hidden; route_output reveals at most
            # one of them per update.
            output_link = gr.HTML(visible=False)
            output_file = gr.File(label="Download Quantized Model", visible=False)

    def route_output(*args):
        """Fan each run_quantization update out to the link and file widgets.

        Yields ``(status, link_update, file_update)``. An update whose value
        is a string starting with ``<a`` goes to the HTML link; any other
        update that carries a value goes to the file widget; the widget not
        chosen (or both, when there is no value) is hidden.
        """
        def hidden():
            # Build a fresh update each time rather than sharing one object.
            return gr.update(visible=False)

        for message, payload in run_quantization(*args):
            carries_value = bool(payload) and isinstance(payload, dict) and "value" in payload
            if not carries_value:
                yield message, hidden(), hidden()
                continue

            content = payload["value"]
            is_link_markup = isinstance(content, str) and content.startswith("<a")
            if is_link_markup:
                yield message, payload, hidden()
            else:
                yield message, hidden(), payload

    # Inputs are forwarded positionally via route_output(*args) into
    # run_quantization, so their order must match that function's parameters
    # (note: exclude_layers comes before full_precision, hf_token is last).
    run_btn.click(
        fn=route_output,
        inputs=[
            source_repo, source_file, target_repo, target_file_base,
            quant_format, layer_filter, exclude_layers, full_precision,
            hf_token
        ],
        # Matches the 3-tuples yielded by route_output: status, link, file.
        outputs=[status_text, output_link, output_file]
    )

# Launch the app only when executed as a script, not when imported.
# NOTE(review): css_paths/theme are passed to launch() here rather than to
# gr.Blocks(); confirm the pinned Gradio version accepts them on launch().
if __name__ == "__main__":
    demo.launch(css_paths=["assets/responsive.css"], theme=gr.themes.Default(primary_hue="blue", neutral_hue="zinc"))