"""
HuggingFace Space for model quantization using convert_to_quant.

Provides a Gradio interface for quantizing safetensors models to various
FP8/INT8 formats for ComfyUI inference, with HuggingFace Hub integration.
"""

import os
import re
import tempfile

import gradio as gr
import spaces  # ZeroGPU - required for space to function
from huggingface_hub import hf_hub_download, HfApi, create_commit, CommitOperationAdd

from convert_to_quant import convert, ConversionConfig

# HF Token from environment.
# NOTE(review): this single token is used for every visitor's request, both to
# download from a user-supplied source repo and to open PRs on a user-supplied
# target repo. Any visitor can therefore read files the token can access
# (including private/gated repos) and open PRs under the token owner's identity.
# Confirm this is intended for the deployment.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model filter presets - simplified combined options
# Maps display name to list of filter flags to enable
MODEL_FILTER_PRESETS = {
    # Text Encoders
    "T5-XXL": {
        "description": "T5-XXL text encoder: skip norms/biases, remove decoder layers",
        "flags": ["t5xxl"],
    },
    "Mistral": {
        "description": "Mistral text encoder exclusions",
        "flags": ["mistral"],
    },
    "Visual Encoder": {
        "description": "Visual encoder: skip MLP layers (down/up/gate proj)",
        "flags": ["visual"],
    },
    # Diffusion Models
    "Flux.2": {
        "description": "Flux.2: keep modulation/guidance/time/final layers high-precision",
        "flags": ["flux2"],
    },
    "Chroma": {
        "description": "Chroma/distilled models: keep distilled_guidance, final, img/txt_in high-precision",
        "flags": ["distillation_large"],
    },
    "Radiance": {
        "description": "Radiance/NeRF models: keep nerf_blocks, img_in_patch, nerf_final_layer high-precision",
        "flags": ["nerf_large", "radiance"],
    },
    # Video Models
    "WAN Video": {
        "description": "WAN video model: skip embeddings, encoders, head",
        "flags": ["wan"],
    },
    "Hunyuan Video": {
        "description": "Hunyuan Video 1.5: skip layernorm, attn norms, vision_in",
        "flags": ["hunyuan"],
    },
    # Image Models
    "Qwen Image": {
        "description": "Qwen Image: skip added norms, keep time_text_embed high-precision",
        "flags": ["qwen"],
    },
    "Z-Image": {
        "description": "Z-Image models: skip cap_embedder/norms, keep x_embedder/final/refiners high-precision",
        "flags": ["zimage", "zimage_refiner"],
    },
}

# Quantization format options
# NOTE: MXFP8/NVFP4 temporarily disabled - see TODO.md
QUANT_FORMATS = {
    "FP8 Tensorwise": {"format": "fp8", "scaling_mode": "tensor", "block_size": None},
    "FP8 Block (128)": {"format": "fp8", "scaling_mode": "block", "block_size": 128},
    "FP8 Block (64)": {"format": "fp8", "scaling_mode": "block", "block_size": 64},
    "INT8 Block (128)": {"format": "int8", "scaling_mode": None, "block_size": 128},
}


def download_model_from_hub(source_repo: str, file_path: str) -> tuple[str | None, str]:
    """
    Download a model file from HuggingFace Hub.

    Surrounding whitespace in either argument is ignored, because the values
    come straight from free-text UI fields.

    Args:
        source_repo: Repository ID (username/repo_name)
        file_path: Path to file within the repository

    Returns:
        Tuple of (local_path, status_message); local_path is None on failure.
    """
    source_repo = (source_repo or "").strip()
    file_path = (file_path or "").strip()
    if not source_repo or not file_path:
        return None, "❌ Please provide both source repository and file path"

    try:
        local_path = hf_hub_download(
            repo_id=source_repo,
            filename=file_path,
            token=HF_TOKEN,
            local_dir=tempfile.gettempdir(),
        )
    except Exception as e:
        return None, f"❌ Download failed: {str(e)}"
    return local_path, f"✅ Downloaded: `{file_path}` from `{source_repo}`"


def upload_model_as_pr(
    local_path: str,
    target_repo: str,
    target_path: str,
    pr_title: str,
) -> str:
    """
    Upload a model file to HuggingFace Hub as a Pull Request.

    Args:
        local_path: Path to local file
        target_repo: Target repository ID (username/repo_name)
        target_path: Path within the target repository
        pr_title: Title for the pull request (also used as the commit message)

    Returns:
        Status message with PR URL, or an error message (never raises).
    """
    if not local_path or not os.path.exists(local_path):
        return "❌ No file to upload"
    if not target_repo or not target_path:
        return "❌ Please provide target repository and path"

    try:
        api = HfApi(token=HF_TOKEN)
        # Create commit as PR
        commit_info = api.create_commit(
            repo_id=target_repo,
            operations=[
                CommitOperationAdd(
                    path_in_repo=target_path,
                    path_or_fileobj=local_path,
                )
            ],
            commit_message=pr_title or "Add quantized model",
            create_pr=True,
        )
    except Exception as e:
        return f"❌ Upload failed: {str(e)}"
    pr_url = commit_info.pr_url
    return f"✅ Pull Request created: [{pr_url}]({pr_url})"


@spaces.GPU(duration=30)
def completion_signal():
    """Brief GPU allocation at completion to satisfy ZeroGPU requirements."""
    return True


def run_quantization(config):
    """Run quantization (CPU-based, no GPU timeout)."""
    return convert(config)


def _build_filter_flags(model_preset: str) -> dict[str, bool]:
    """Return the converter filter flags enabled by a preset ({} for "None"/unknown)."""
    if not model_preset or model_preset == "None":
        return {}
    preset_config = MODEL_FILTER_PRESETS.get(model_preset, {})
    # A preset may enable several flags at once.
    return {flag: True for flag in preset_config.get("flags", [])}


def _normalize_exclude_regex(pattern: str | None) -> tuple[str | None, str | None]:
    """
    Strip and validate the user's layer-exclusion regex.

    Returns:
        (pattern, error): pattern is None when the field is blank; error is a
        status message when the pattern does not compile, otherwise None.
    """
    pattern = (pattern or "").strip()
    if not pattern:
        return None, None
    try:
        # Assumes the converter uses Python `re` syntax - TODO confirm.
        re.compile(pattern)
    except re.error as e:
        return None, f"❌ Invalid exclude-layers regex: {e}"
    return pattern, None


def quantize_model(
    source_repo: str,
    file_path: str,
    quant_format: str,
    model_preset: str,
    exclude_layers_regex: str,
    full_precision_matmul: bool,
    target_repo: str,
    target_path: str,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Download, quantize, and optionally upload a model.

    Cheap validation (format, exclusion regex) runs before the download, so
    bad settings are reported without fetching a potentially large file.

    Args:
        source_repo: Source HuggingFace repository (username/repo_name)
        file_path: Path to model file in source repo
        quant_format: Selected quantization format (key of QUANT_FORMATS)
        model_preset: Model preset filter (or "None")
        exclude_layers_regex: Regex pattern for layers to exclude
        full_precision_matmul: Enable full precision matrix multiplication
        target_repo: Target repository for upload (optional)
        target_path: Target path in repository (optional)
        progress: Injected by Gradio; track_tqdm=True mirrors tqdm bars in the UI

    Returns:
        Tuple of (output_file_path, status_message); output_file_path is None
        on failure.
    """
    status_log = []

    def _finish(output_path):
        return output_path, "\n\n".join(status_log)

    # Step 1: Validate settings before the (possibly large) download
    format_config = QUANT_FORMATS.get(quant_format)
    if not format_config:
        status_log.append(f"❌ Unknown format: {quant_format}")
        return _finish(None)

    exclude_pattern, regex_error = _normalize_exclude_regex(exclude_layers_regex)
    if regex_error:
        status_log.append(regex_error)
        return _finish(None)

    filter_flags = _build_filter_flags(model_preset)

    # Step 2: Download model from Hub
    status_log.append("📥 **Downloading model...**")
    input_path, download_status = download_model_from_hub(source_repo, file_path)
    status_log.append(download_status)
    if input_path is None:
        return _finish(None)

    # Generate output filename, e.g. "FP8 Block (128)" -> "<base>_fp8_block_128"
    base_name = os.path.splitext(os.path.basename(input_path))[0]
    format_suffix = quant_format.lower().replace(" ", "_").replace("(", "").replace(")", "")
    output_name = f"{base_name}_{format_suffix}.safetensors"
    output_path = os.path.join(tempfile.gettempdir(), output_name)

    # Step 3: Build conversion config and quantize
    status_log.append(f"⚙️ **Quantizing with {quant_format}...**")
    try:
        config = ConversionConfig(
            input_path=input_path,
            output_path=output_path,
            quant_format=format_config["format"],
            comfy_quant=True,
            save_quant_metadata=True,
            simple=True,
            verbose="VERBOSE",
            # INT8 entries carry scaling_mode=None; fall back to "tensor".
            scaling_mode=format_config.get("scaling_mode") or "tensor",
            block_size=format_config.get("block_size"),
            filter_flags=filter_flags,
            exclude_layers=exclude_pattern,
            full_precision_matrix_mult=full_precision_matmul,
            force_cpu=True,  # Bypass CUDA checks on ZeroGPU
        )
        result = run_quantization(config)
    except Exception as e:
        status_log.append(f"❌ Error: {str(e)}")
        return _finish(None)

    if not result.success:
        status_log.append(f"❌ Quantization failed: {result.error}")
        return _finish(None)
    status_log.append(f"✅ Quantization complete: `{os.path.basename(result.output_path)}`")

    # Step 4: Upload to target repo if specified (upload_model_as_pr never raises)
    if target_repo and target_repo.strip():
        status_log.append("📤 **Uploading as Pull Request...**")
        # Use provided target path or generate one
        upload_path = target_path.strip() if target_path and target_path.strip() else output_name
        pr_title = f"Add {quant_format} quantized model: {output_name}"
        upload_status = upload_model_as_pr(
            result.output_path,
            target_repo.strip(),
            upload_path,
            pr_title,
        )
        status_log.append(upload_status)

    # Brief GPU allocation at completion to satisfy ZeroGPU. Best-effort: a
    # failure here (e.g. GPU quota) must not discard a finished quantization.
    try:
        completion_signal()
    except Exception as e:
        status_log.append(f"⚠️ ZeroGPU completion signal failed (output is still available): {str(e)}")

    return _finish(result.output_path)


# Build Gradio interface
DESCRIPTION = """
# Model Quantization for ComfyUI

Quantize safetensors models to FP8/INT8 formats for efficient inference in ComfyUI.

## Workflow
1. Enter source repository and file path to download model
2. Select quantization format and model preset
3. Optionally configure target repo to upload as Pull Request

## Quantization Formats
- **FP8 Tensorwise**: Standard FP8 with per-tensor scaling (most compatible)
- **FP8 Block**: FP8 with per-block scaling (better accuracy)
- **INT8 Block**: INT8 with per-block scaling (Triton-based)
- **MXFP8/NVFP4**: Next-gen formats (requires Blackwell GPU) - temporarily unavailable in this Space
"""

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### 📥 Source Model")
            source_repo = gr.Textbox(
                label="Source Repository",
                placeholder="username/model-repo",
                info="HuggingFace repository ID",
            )
            file_path = gr.Textbox(
                label="File Path in Repository",
                placeholder="model.safetensors or path/to/model.safetensors",
                info="Path to the safetensors file within the repo",
            )

            gr.Markdown("### ⚙️ Quantization Settings")
            quant_format = gr.Dropdown(
                label="Quantization Format",
                choices=list(QUANT_FORMATS.keys()),
                value="FP8 Tensorwise",
            )
            model_preset = gr.Dropdown(
                label="Model Preset",
                choices=["None"] + list(MODEL_FILTER_PRESETS.keys()),
                value="None",
                info="Preset layer exclusions for specific architectures",
            )

            with gr.Accordion("Advanced Options", open=False):
                exclude_layers = gr.Textbox(
                    label="Exclude Layers (Regex)",
                    placeholder="e.g., img_in|txt_in|final_layer",
                    info="Regex pattern for layers to keep in original precision",
                )
                full_precision_matmul = gr.Checkbox(
                    label="Full Precision Matrix Multiply",
                    value=False,
                    info="Use FP32 matmul (for storage-only quantization)",
                )

            gr.Markdown("### 📤 Upload Target (Optional)")
            target_repo = gr.Textbox(
                label="Target Repository",
                placeholder="username/target-repo",
                info="Leave empty to skip upload",
            )
            target_path = gr.Textbox(
                label="Target Path in Repository",
                placeholder="quantized/model_fp8.safetensors",
                info="Path where file will be uploaded (optional)",
            )

            quantize_btn = gr.Button("🚀 Quantize Model", variant="primary", size="lg")

        with gr.Column(scale=1):
            output_file = gr.File(
                label="Download Quantized Model",
                interactive=False,
            )
            status = gr.Markdown(label="Status")

            # Show preset description when selected
            preset_info = gr.Markdown("")

    # Update preset info when selection changes
    def update_preset_info(preset):
        """Return a Markdown line describing the selected preset ("" for "None")."""
        if preset and preset != "None":
            preset_config = MODEL_FILTER_PRESETS.get(preset, {})
            desc = preset_config.get("description", "")
            return f"**{preset}**: {desc}"
        return ""

    model_preset.change(
        fn=update_preset_info,
        inputs=[model_preset],
        outputs=[preset_info],
    )

    # Run quantization
    quantize_btn.click(
        fn=quantize_model,
        inputs=[
            source_repo,
            file_path,
            quant_format,
            model_preset,
            exclude_layers,
            full_precision_matmul,
            target_repo,
            target_path,
        ],
        outputs=[output_file, status],
    )


if __name__ == "__main__":
    demo.launch(theme=gr.themes.Soft())