# fix: Remove torch.cuda.is_available() to prevent CUDA init in main process (commit f9d4d6e)
"""
HuggingFace Space for model quantization using convert_to_quant.
Provides a Gradio interface for quantizing safetensors models to various
FP8/INT8 formats for ComfyUI inference, with HuggingFace Hub integration.
"""
import os
import tempfile
import gradio as gr
import spaces # ZeroGPU - required for space to function
from huggingface_hub import hf_hub_download, HfApi, create_commit, CommitOperationAdd
from convert_to_quant import convert, ConversionConfig
# HF Token from environment
# Used for both Hub downloads (private repos) and opening PRs on target repos.
# May be None; anonymous access still works for public downloads.
HF_TOKEN: str | None = os.environ.get("HF_TOKEN")
# Model filter presets - simplified combined options
# Maps display name to list of filter flags to enable
# Each entry: "description" is surfaced in the UI below the dropdown;
# "flags" are convert_to_quant filter flags switched on for that preset.
MODEL_FILTER_PRESETS = {
    # Text Encoders
    "T5-XXL": {
        "description": "T5-XXL text encoder: skip norms/biases, remove decoder layers",
        "flags": ["t5xxl"],
    },
    "Mistral": {
        "description": "Mistral text encoder exclusions",
        "flags": ["mistral"],
    },
    "Visual Encoder": {
        "description": "Visual encoder: skip MLP layers (down/up/gate proj)",
        "flags": ["visual"],
    },
    # Diffusion Models
    "Flux.2": {
        "description": "Flux.2: keep modulation/guidance/time/final layers high-precision",
        "flags": ["flux2"],
    },
    "Chroma": {
        "description": "Chroma/distilled models: keep distilled_guidance, final, img/txt_in high-precision",
        "flags": ["distillation_large"],
    },
    "Radiance": {
        "description": "Radiance/NeRF models: keep nerf_blocks, img_in_patch, nerf_final_layer high-precision",
        "flags": ["nerf_large", "radiance"],
    },
    # Video Models
    "WAN Video": {
        "description": "WAN video model: skip embeddings, encoders, head",
        "flags": ["wan"],
    },
    "Hunyuan Video": {
        "description": "Hunyuan Video 1.5: skip layernorm, attn norms, vision_in",
        "flags": ["hunyuan"],
    },
    # Image Models
    "Qwen Image": {
        "description": "Qwen Image: skip added norms, keep time_text_embed high-precision",
        "flags": ["qwen"],
    },
    "Z-Image": {
        "description": "Z-Image models: skip cap_embedder/norms, keep x_embedder/final/refiners high-precision",
        "flags": ["zimage", "zimage_refiner"],
    },
}
# Quantization format options
# NOTE: MXFP8/NVFP4 temporarily disabled - see TODO.md
# Keys are UI display names; values feed ConversionConfig in quantize_model
# (a None scaling_mode falls back to "tensor" there).
QUANT_FORMATS = {
    "FP8 Tensorwise": {"format": "fp8", "scaling_mode": "tensor", "block_size": None},
    "FP8 Block (128)": {"format": "fp8", "scaling_mode": "block", "block_size": 128},
    "FP8 Block (64)": {"format": "fp8", "scaling_mode": "block", "block_size": 64},
    "INT8 Block (128)": {"format": "int8", "scaling_mode": None, "block_size": 128},
}
def download_model_from_hub(source_repo: str, file_path: str) -> tuple[str | None, str]:
    """
    Fetch a single model file from a HuggingFace Hub repository.

    Args:
        source_repo: Repository ID (username/repo_name)
        file_path: Path to file within the repository

    Returns:
        Tuple of (local_path, status_message); local_path is None on failure.
    """
    # Guard clause: both fields are required before touching the Hub.
    if not (source_repo and file_path):
        return None, "❌ Please provide both source repository and file path"
    try:
        downloaded = hf_hub_download(
            repo_id=source_repo,
            filename=file_path,
            token=HF_TOKEN,
            local_dir=tempfile.gettempdir(),
        )
    except Exception as err:
        return None, f"❌ Download failed: {str(err)}"
    return downloaded, f"✅ Downloaded: `{file_path}` from `{source_repo}`"
def upload_model_as_pr(
    local_path: str,
    target_repo: str,
    target_path: str,
    pr_title: str,
) -> str:
    """
    Push a local file to a HuggingFace repository via a Pull Request commit.

    Args:
        local_path: Path to local file
        target_repo: Target repository ID (username/repo_name)
        target_path: Path within the target repository
        pr_title: Title for the pull request

    Returns:
        Markdown status message containing the PR URL on success.
    """
    # Validate inputs up front so we never hit the Hub API with bad arguments.
    if not local_path or not os.path.exists(local_path):
        return "❌ No file to upload"
    if not (target_repo and target_path):
        return "❌ Please provide target repository and path"
    try:
        client = HfApi(token=HF_TOKEN)
        # create_pr=True turns the commit into a Pull Request on the target repo.
        add_op = CommitOperationAdd(
            path_in_repo=target_path,
            path_or_fileobj=local_path,
        )
        info = client.create_commit(
            repo_id=target_repo,
            operations=[add_op],
            commit_message=pr_title or "Add quantized model",
            create_pr=True,
        )
        link = info.pr_url
        return f"✅ Pull Request created: [{link}]({link})"
    except Exception as err:
        return f"❌ Upload failed: {str(err)}"
@spaces.GPU(duration=30)
def completion_signal() -> bool:
    """Brief GPU allocation at completion to satisfy ZeroGPU requirements.

    The quantization itself runs on CPU (force_cpu=True in quantize_model),
    but ZeroGPU Spaces need at least one @spaces.GPU call to stay functional,
    so this no-op is invoked once after a successful run.

    Returns:
        True (the value itself is unused by callers).
    """
    return True
def run_quantization(config: ConversionConfig):
    """Run quantization (CPU-based, no GPU timeout).

    Deliberately NOT decorated with @spaces.GPU: the conversion runs on CPU
    (the caller sets force_cpu=True), so no ZeroGPU duration limit applies.

    Args:
        config: ConversionConfig describing input/output paths and quant settings.

    Returns:
        Result object from convert() (callers use .success, .error, .output_path).
    """
    return convert(config)
def quantize_model(
    source_repo: str,
    file_path: str,
    quant_format: str,
    model_preset: str,
    exclude_layers_regex: str,
    full_precision_matmul: bool,
    target_repo: str,
    target_path: str,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Download, quantize, and optionally upload a model.

    Args:
        source_repo: Source HuggingFace repository (username/repo_name)
        file_path: Path to model file in source repo
        quant_format: Selected quantization format (key of QUANT_FORMATS)
        model_preset: Model preset filter (key of MODEL_FILTER_PRESETS, or "None")
        exclude_layers_regex: Regex pattern for layers to exclude
        full_precision_matmul: Enable full precision matrix multiplication
        target_repo: Target repository for upload (optional; empty skips upload)
        target_path: Target path in repository (optional; defaults to output name)
        progress: Gradio progress tracker (mirrors tqdm bars from the converter)

    Returns:
        Tuple of (output_file_path or None, markdown status message)
    """
    status_log = []

    # Step 1: Download model from Hub
    status_log.append("📥 **Downloading model...**")
    input_path, download_status = download_model_from_hub(source_repo, file_path)
    status_log.append(download_status)
    if input_path is None:
        return None, "\n\n".join(status_log)

    # Step 2: Resolve format settings for the chosen display name
    format_config = QUANT_FORMATS.get(quant_format)
    if not format_config:
        status_log.append(f"❌ Unknown format: {quant_format}")
        return None, "\n\n".join(status_log)

    # Build filter flags from preset (a preset can enable multiple flags)
    filter_flags = {}
    if model_preset and model_preset != "None":
        preset_config = MODEL_FILTER_PRESETS.get(model_preset)
        if preset_config:
            for flag in preset_config.get("flags", []):
                filter_flags[flag] = True

    # Generate output filename, e.g. "model_fp8_block_128.safetensors"
    base_name = os.path.splitext(os.path.basename(input_path))[0]
    format_suffix = quant_format.lower().replace(" ", "_").replace("(", "").replace(")", "")
    output_name = f"{base_name}_{format_suffix}.safetensors"
    output_path = os.path.join(tempfile.gettempdir(), output_name)

    # Step 3: Build conversion config
    status_log.append(f"⚙️ **Quantizing with {quant_format}...**")
    config = ConversionConfig(
        input_path=input_path,
        output_path=output_path,
        quant_format=format_config["format"],
        comfy_quant=True,
        save_quant_metadata=True,
        simple=True,
        verbose="VERBOSE",
        scaling_mode=format_config.get("scaling_mode") or "tensor",
        block_size=format_config.get("block_size"),
        filter_flags=filter_flags,
        exclude_layers=exclude_layers_regex if exclude_layers_regex and exclude_layers_regex.strip() else None,
        full_precision_matrix_mult=full_precision_matmul,
        force_cpu=True,  # Bypass CUDA checks on ZeroGPU
    )
    try:
        result = run_quantization(config)
        if not result.success:
            status_log.append(f"❌ Quantization failed: {result.error}")
            return None, "\n\n".join(status_log)
        status_log.append(f"✅ Quantization complete: `{os.path.basename(result.output_path)}`")

        # Step 4: Upload to target repo if specified
        if target_repo and target_repo.strip():
            status_log.append("📤 **Uploading as Pull Request...**")
            # Use provided target path or fall back to the generated filename
            upload_path = target_path.strip() if target_path and target_path.strip() else output_name
            pr_title = f"Add {quant_format} quantized model: {output_name}"
            upload_status = upload_model_as_pr(
                result.output_path,
                target_repo.strip(),
                upload_path,
                pr_title,
            )
            status_log.append(upload_status)

        # Step 5: Brief GPU allocation at completion to satisfy ZeroGPU.
        # FIX: guarded separately — previously a ZeroGPU allocation failure
        # (e.g. quota exhausted) fell into the generic handler below and
        # discarded an already-successful CPU quantization.
        try:
            completion_signal()
        except Exception as gpu_err:
            status_log.append(f"⚠️ ZeroGPU completion signal failed (non-fatal): {gpu_err}")

        return result.output_path, "\n\n".join(status_log)
    except Exception as e:
        status_log.append(f"❌ Error: {str(e)}")
        return None, "\n\n".join(status_log)
# Build Gradio interface
# NOTE: MXFP8/NVFP4 are listed as disabled here to stay consistent with
# QUANT_FORMATS, which does not currently offer them (see TODO.md).
DESCRIPTION = """
# Model Quantization for ComfyUI
Quantize safetensors models to FP8/INT8 formats for efficient inference in ComfyUI.
## Workflow
1. Enter source repository and file path to download model
2. Select quantization format and model preset
3. Optionally configure target repo to upload as Pull Request
## Quantization Formats
- **FP8 Tensorwise**: Standard FP8 with per-tensor scaling (most compatible)
- **FP8 Block**: FP8 with per-block scaling (better accuracy)
- **INT8 Block**: INT8 with per-block scaling (Triton-based)
- **MXFP8/NVFP4**: Next-gen formats (requires Blackwell GPU) — temporarily disabled
"""
# FIX: the theme belongs on the gr.Blocks() constructor — Blocks.launch()
# does not accept a `theme` keyword argument.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Left column: all inputs (source, quant settings, upload target).
        with gr.Column(scale=2):
            gr.Markdown("### 📥 Source Model")
            source_repo = gr.Textbox(
                label="Source Repository",
                placeholder="username/model-repo",
                info="HuggingFace repository ID",
            )
            file_path = gr.Textbox(
                label="File Path in Repository",
                placeholder="model.safetensors or path/to/model.safetensors",
                info="Path to the safetensors file within the repo",
            )
            gr.Markdown("### ⚙️ Quantization Settings")
            quant_format = gr.Dropdown(
                label="Quantization Format",
                choices=list(QUANT_FORMATS.keys()),
                value="FP8 Tensorwise",
            )
            model_preset = gr.Dropdown(
                label="Model Preset",
                choices=["None"] + list(MODEL_FILTER_PRESETS.keys()),
                value="None",
                info="Preset layer exclusions for specific architectures",
            )
            with gr.Accordion("Advanced Options", open=False):
                exclude_layers = gr.Textbox(
                    label="Exclude Layers (Regex)",
                    placeholder="e.g., img_in|txt_in|final_layer",
                    info="Regex pattern for layers to keep in original precision",
                )
                full_precision_matmul = gr.Checkbox(
                    label="Full Precision Matrix Multiply",
                    value=False,
                    info="Use FP32 matmul (for storage-only quantization)",
                )
            gr.Markdown("### 📤 Upload Target (Optional)")
            target_repo = gr.Textbox(
                label="Target Repository",
                placeholder="username/target-repo",
                info="Leave empty to skip upload",
            )
            target_path = gr.Textbox(
                label="Target Path in Repository",
                placeholder="quantized/model_fp8.safetensors",
                info="Path where file will be uploaded (optional)",
            )
            quantize_btn = gr.Button("🚀 Quantize Model", variant="primary", size="lg")
        # Right column: outputs (downloadable file + status markdown).
        with gr.Column(scale=1):
            output_file = gr.File(
                label="Download Quantized Model",
                interactive=False,
            )
            status = gr.Markdown(label="Status")
            # Show preset description when selected
            preset_info = gr.Markdown("")

    # Update preset info when selection changes
    def update_preset_info(preset):
        """Return a markdown blurb describing the chosen preset, or ''."""
        if preset and preset != "None":
            preset_config = MODEL_FILTER_PRESETS.get(preset, {})
            desc = preset_config.get("description", "")
            return f"**{preset}**: {desc}"
        return ""

    model_preset.change(
        fn=update_preset_info,
        inputs=[model_preset],
        outputs=[preset_info],
    )

    # Run quantization on button click; inputs must stay in the same order
    # as quantize_model's positional parameters.
    quantize_btn.click(
        fn=quantize_model,
        inputs=[
            source_repo,
            file_path,
            quant_format,
            model_preset,
            exclude_layers,
            full_precision_matmul,
            target_repo,
            target_path,
        ],
        outputs=[output_file, status],
    )
if __name__ == "__main__":
    # FIX: Blocks.launch() takes no `theme` argument (themes are passed to the
    # gr.Blocks() constructor); the previous call raised a TypeError at startup.
    demo.launch()