"""
HuggingFace Space for model quantization using convert_to_quant.
Provides a Gradio interface for quantizing safetensors models to various
FP8/INT8 formats for ComfyUI inference, with HuggingFace Hub integration.
"""
import os
import tempfile
import gradio as gr
import spaces # ZeroGPU - required for space to function
from huggingface_hub import hf_hub_download, HfApi, create_commit, CommitOperationAdd
from convert_to_quant import convert, ConversionConfig
# HF Token from environment.
# NOTE(review): may be None if the Space secret is unset; Hub calls then run
# unauthenticated — confirm private-repo access is not required.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Model filter presets - simplified combined options
# Maps display name to a description (shown in the UI) and the list of
# convert_to_quant filter flags to enable for that architecture.
MODEL_FILTER_PRESETS = {
    # Text Encoders
    "T5-XXL": {
        "description": "T5-XXL text encoder: skip norms/biases, remove decoder layers",
        "flags": ["t5xxl"],
    },
    "Mistral": {
        "description": "Mistral text encoder exclusions",
        "flags": ["mistral"],
    },
    "Visual Encoder": {
        "description": "Visual encoder: skip MLP layers (down/up/gate proj)",
        "flags": ["visual"],
    },
    # Diffusion Models
    "Flux.2": {
        "description": "Flux.2: keep modulation/guidance/time/final layers high-precision",
        "flags": ["flux2"],
    },
    "Chroma": {
        "description": "Chroma/distilled models: keep distilled_guidance, final, img/txt_in high-precision",
        "flags": ["distillation_large"],
    },
    "Radiance": {
        "description": "Radiance/NeRF models: keep nerf_blocks, img_in_patch, nerf_final_layer high-precision",
        "flags": ["nerf_large", "radiance"],
    },
    # Video Models
    "WAN Video": {
        "description": "WAN video model: skip embeddings, encoders, head",
        "flags": ["wan"],
    },
    "Hunyuan Video": {
        "description": "Hunyuan Video 1.5: skip layernorm, attn norms, vision_in",
        "flags": ["hunyuan"],
    },
    # Image Models
    "Qwen Image": {
        "description": "Qwen Image: skip added norms, keep time_text_embed high-precision",
        "flags": ["qwen"],
    },
    "Z-Image": {
        "description": "Z-Image models: skip cap_embedder/norms, keep x_embedder/final/refiners high-precision",
        "flags": ["zimage", "zimage_refiner"],
    },
}
# Quantization format options
# NOTE: MXFP8/NVFP4 temporarily disabled - see TODO.md
# A scaling_mode of None falls back to "tensor" in quantize_model; block_size
# of None means the format does not use block scaling.
QUANT_FORMATS = {
    "FP8 Tensorwise": {"format": "fp8", "scaling_mode": "tensor", "block_size": None},
    "FP8 Block (128)": {"format": "fp8", "scaling_mode": "block", "block_size": 128},
    "FP8 Block (64)": {"format": "fp8", "scaling_mode": "block", "block_size": 64},
    "INT8 Block (128)": {"format": "int8", "scaling_mode": None, "block_size": 128},
}
def download_model_from_hub(source_repo: str, file_path: str) -> tuple[str | None, str]:
"""
Download a model file from HuggingFace Hub.
Args:
source_repo: Repository ID (username/repo_name)
file_path: Path to file within the repository
Returns:
Tuple of (local_path, status_message)
"""
if not source_repo or not file_path:
return None, "❌ Please provide both source repository and file path"
try:
local_path = hf_hub_download(
repo_id=source_repo,
filename=file_path,
token=HF_TOKEN,
local_dir=tempfile.gettempdir(),
)
return local_path, f"✅ Downloaded: `{file_path}` from `{source_repo}`"
except Exception as e:
return None, f"❌ Download failed: {str(e)}"
def upload_model_as_pr(
    local_path: str,
    target_repo: str,
    target_path: str,
    pr_title: str,
) -> str:
    """
    Push a local model file to a Hub repository as a Pull Request.

    Args:
        local_path: Path to local file
        target_repo: Target repository ID (username/repo_name)
        target_path: Path within the target repository
        pr_title: Title for the pull request

    Returns:
        Markdown status message: a PR link on success, error text otherwise.
    """
    # Guard clauses: need a real file on disk and a complete destination.
    file_ready = bool(local_path) and os.path.exists(local_path)
    if not file_ready:
        return "❌ No file to upload"
    if not (target_repo and target_path):
        return "❌ Please provide target repository and path"
    try:
        hub = HfApi(token=HF_TOKEN)
        # create_pr=True opens the commit as a Pull Request instead of
        # pushing directly to the default branch.
        info = hub.create_commit(
            repo_id=target_repo,
            operations=[
                CommitOperationAdd(
                    path_in_repo=target_path,
                    path_or_fileobj=local_path,
                )
            ],
            commit_message=pr_title or "Add quantized model",
            create_pr=True,
        )
    except Exception as e:
        return f"❌ Upload failed: {str(e)}"
    link = info.pr_url
    return f"✅ Pull Request created: [{link}]({link})"
# ZeroGPU Spaces require at least one @spaces.GPU-decorated function to be
# invoked; the actual quantization runs on CPU (see run_quantization), so this
# no-op claims the GPU briefly at the end of a job instead.
@spaces.GPU(duration=30)
def completion_signal() -> bool:
    """Brief GPU allocation at completion to satisfy ZeroGPU requirements."""
    return True
def run_quantization(config: ConversionConfig):
    """Run quantization (CPU-based, no GPU timeout).

    Deliberately not decorated with @spaces.GPU: conversion can run longer
    than a ZeroGPU allocation window, so it executes on CPU without a
    duration limit (completion_signal handles the GPU requirement).
    """
    return convert(config)
def quantize_model(
    source_repo: str,
    file_path: str,
    quant_format: str,
    model_preset: str,
    exclude_layers_regex: str,
    full_precision_matmul: bool,
    target_repo: str,
    target_path: str,
    progress=gr.Progress(track_tqdm=True),
) -> tuple[str | None, str]:
    """
    Download, quantize, and optionally upload a model.

    Pipeline: download from Hub -> resolve format/preset -> run
    convert_to_quant -> (optionally) open an upload PR. A Markdown status
    log is accumulated across steps and returned even on failure.

    Args:
        source_repo: Source HuggingFace repository (username/repo_name)
        file_path: Path to model file in source repo
        quant_format: Selected quantization format (key of QUANT_FORMATS)
        model_preset: Model preset filter (key of MODEL_FILTER_PRESETS, or "None")
        exclude_layers_regex: Regex pattern for layers to exclude
        full_precision_matmul: Enable full precision matrix multiplication
        target_repo: Target repository for upload (optional)
        target_path: Target path in repository (optional)
        progress: Gradio progress tracker; track_tqdm mirrors tqdm bars in the UI
    Returns:
        Tuple of (output_file_path or None, Markdown status message)
    """
    status_log: list[str] = []
    # Step 1: Download model from Hub
    status_log.append("📥 **Downloading model...**")
    input_path, download_status = download_model_from_hub(source_repo, file_path)
    status_log.append(download_status)
    if input_path is None:
        # Download failed; the reason is already in download_status.
        return None, "\n\n".join(status_log)
    # Step 2: Get format settings
    format_config = QUANT_FORMATS.get(quant_format)
    if not format_config:
        status_log.append(f"❌ Unknown format: {quant_format}")
        return None, "\n\n".join(status_log)
    # Build filter flags from preset (can have multiple flags per preset)
    filter_flags = {}
    if model_preset and model_preset != "None":
        preset_config = MODEL_FILTER_PRESETS.get(model_preset)
        if preset_config:
            for flag in preset_config.get("flags", []):
                filter_flags[flag] = True
    # Generate output filename, e.g. "model_fp8_block_128.safetensors"
    base_name = os.path.splitext(os.path.basename(input_path))[0]
    format_suffix = quant_format.lower().replace(" ", "_").replace("(", "").replace(")", "")
    output_name = f"{base_name}_{format_suffix}.safetensors"
    output_path = os.path.join(tempfile.gettempdir(), output_name)
    # Step 3: Build conversion config
    status_log.append(f"⚙️ **Quantizing with {quant_format}...**")
    config = ConversionConfig(
        input_path=input_path,
        output_path=output_path,
        quant_format=format_config["format"],
        comfy_quant=True,
        save_quant_metadata=True,
        simple=True,
        verbose="VERBOSE",
        # INT8 entries carry scaling_mode=None; fall back to "tensor".
        scaling_mode=format_config.get("scaling_mode") or "tensor",
        block_size=format_config.get("block_size"),
        filter_flags=filter_flags,
        # Pass the regex only when non-blank; whitespace-only input means "no exclusions".
        exclude_layers=exclude_layers_regex if exclude_layers_regex and exclude_layers_regex.strip() else None,
        full_precision_matrix_mult=full_precision_matmul,
        force_cpu=True,  # Bypass CUDA checks on ZeroGPU
    )
    try:
        result = run_quantization(config)
        if not result.success:
            status_log.append(f"❌ Quantization failed: {result.error}")
            return None, "\n\n".join(status_log)
        status_log.append(f"✅ Quantization complete: `{os.path.basename(result.output_path)}`")
        # Step 4: Upload to target repo if specified
        if target_repo and target_repo.strip():
            status_log.append("📤 **Uploading as Pull Request...**")
            # Use provided target path or generate one
            upload_path = target_path.strip() if target_path and target_path.strip() else output_name
            pr_title = f"Add {quant_format} quantized model: {output_name}"
            upload_status = upload_model_as_pr(
                result.output_path,
                target_repo.strip(),
                upload_path,
                pr_title,
            )
            # NOTE(review): an upload failure still returns the quantized file
            # below — presumably intentional (best-effort upload); confirm.
            status_log.append(upload_status)
        # Brief GPU allocation at completion to satisfy ZeroGPU
        completion_signal()
        return result.output_path, "\n\n".join(status_log)
    except Exception as e:
        status_log.append(f"❌ Error: {str(e)}")
        return None, "\n\n".join(status_log)
# Markdown shown at the top of the Gradio page.
# Fix: the format list previously advertised MXFP8/NVFP4, which are not in
# QUANT_FORMATS (temporarily disabled - see TODO.md) and so cannot be selected.
DESCRIPTION = """
# Model Quantization for ComfyUI
Quantize safetensors models to FP8/INT8 formats for efficient inference in ComfyUI.
## Workflow
1. Enter source repository and file path to download model
2. Select quantization format and model preset
3. Optionally configure target repo to upload as Pull Request
## Quantization Formats
- **FP8 Tensorwise**: Standard FP8 with per-tensor scaling (most compatible)
- **FP8 Block**: FP8 with per-block scaling (better accuracy)
- **INT8 Block**: INT8 with per-block scaling (Triton-based)

*MXFP8/NVFP4 (Blackwell-only) are temporarily disabled — see TODO.md.*
"""
# Build Gradio interface.
# Fix: the theme must be passed to gr.Blocks(); Blocks.launch() has no
# `theme` parameter, so `demo.launch(theme=...)` raised TypeError at startup.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### 📥 Source Model")
            source_repo = gr.Textbox(
                label="Source Repository",
                placeholder="username/model-repo",
                info="HuggingFace repository ID",
            )
            file_path = gr.Textbox(
                label="File Path in Repository",
                placeholder="model.safetensors or path/to/model.safetensors",
                info="Path to the safetensors file within the repo",
            )
            gr.Markdown("### ⚙️ Quantization Settings")
            quant_format = gr.Dropdown(
                label="Quantization Format",
                choices=list(QUANT_FORMATS.keys()),
                value="FP8 Tensorwise",
            )
            model_preset = gr.Dropdown(
                label="Model Preset",
                choices=["None"] + list(MODEL_FILTER_PRESETS.keys()),
                value="None",
                info="Preset layer exclusions for specific architectures",
            )
            with gr.Accordion("Advanced Options", open=False):
                exclude_layers = gr.Textbox(
                    label="Exclude Layers (Regex)",
                    placeholder="e.g., img_in|txt_in|final_layer",
                    info="Regex pattern for layers to keep in original precision",
                )
                full_precision_matmul = gr.Checkbox(
                    label="Full Precision Matrix Multiply",
                    value=False,
                    info="Use FP32 matmul (for storage-only quantization)",
                )
            gr.Markdown("### 📤 Upload Target (Optional)")
            target_repo = gr.Textbox(
                label="Target Repository",
                placeholder="username/target-repo",
                info="Leave empty to skip upload",
            )
            target_path = gr.Textbox(
                label="Target Path in Repository",
                placeholder="quantized/model_fp8.safetensors",
                info="Path where file will be uploaded (optional)",
            )
            quantize_btn = gr.Button("🚀 Quantize Model", variant="primary", size="lg")
        with gr.Column(scale=1):
            output_file = gr.File(
                label="Download Quantized Model",
                interactive=False,
            )
            status = gr.Markdown(label="Status")
            # Shows the description of the currently selected preset.
            preset_info = gr.Markdown("")
    # Update preset info when the dropdown selection changes.
    def update_preset_info(preset: str) -> str:
        """Return a Markdown blurb for the selected preset ('' for None)."""
        if preset and preset != "None":
            preset_config = MODEL_FILTER_PRESETS.get(preset, {})
            desc = preset_config.get("description", "")
            return f"**{preset}**: {desc}"
        return ""
    model_preset.change(
        fn=update_preset_info,
        inputs=[model_preset],
        outputs=[preset_info],
    )
    # Wire the main button to the download -> quantize -> upload pipeline.
    quantize_btn.click(
        fn=quantize_model,
        inputs=[
            source_repo,
            file_path,
            quant_format,
            model_preset,
            exclude_layers,
            full_precision_matmul,
            target_repo,
            target_path,
        ],
        outputs=[output_file, status],
    )
if __name__ == "__main__":
    demo.launch()