Spaces:
Build error
Build error
| """ | |
| HuggingFace Space for model quantization using convert_to_quant. | |
| Provides a Gradio interface for quantizing safetensors models to various | |
| FP8/INT8 formats for ComfyUI inference, with HuggingFace Hub integration. | |
| """ | |
| import os | |
| import tempfile | |
| import gradio as gr | |
| import spaces # ZeroGPU - required for space to function | |
| from huggingface_hub import hf_hub_download, HfApi, create_commit, CommitOperationAdd | |
| from convert_to_quant import convert, ConversionConfig | |
# HF Token from environment.
# May be None: public downloads still work, but PR uploads will fail without it.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Model filter presets - simplified combined options
# Maps display name to list of filter flags to enable
# Each value: "description" shown in the UI, "flags" forwarded to
# ConversionConfig.filter_flags (a preset may enable several flags at once).
MODEL_FILTER_PRESETS = {
    # Text Encoders
    "T5-XXL": {
        "description": "T5-XXL text encoder: skip norms/biases, remove decoder layers",
        "flags": ["t5xxl"],
    },
    "Mistral": {
        "description": "Mistral text encoder exclusions",
        "flags": ["mistral"],
    },
    "Visual Encoder": {
        "description": "Visual encoder: skip MLP layers (down/up/gate proj)",
        "flags": ["visual"],
    },
    # Diffusion Models
    "Flux.2": {
        "description": "Flux.2: keep modulation/guidance/time/final layers high-precision",
        "flags": ["flux2"],
    },
    "Chroma": {
        "description": "Chroma/distilled models: keep distilled_guidance, final, img/txt_in high-precision",
        "flags": ["distillation_large"],
    },
    "Radiance": {
        "description": "Radiance/NeRF models: keep nerf_blocks, img_in_patch, nerf_final_layer high-precision",
        "flags": ["nerf_large", "radiance"],
    },
    # Video Models
    "WAN Video": {
        "description": "WAN video model: skip embeddings, encoders, head",
        "flags": ["wan"],
    },
    "Hunyuan Video": {
        "description": "Hunyuan Video 1.5: skip layernorm, attn norms, vision_in",
        "flags": ["hunyuan"],
    },
    # Image Models
    "Qwen Image": {
        "description": "Qwen Image: skip added norms, keep time_text_embed high-precision",
        "flags": ["qwen"],
    },
    "Z-Image": {
        "description": "Z-Image models: skip cap_embedder/norms, keep x_embedder/final/refiners high-precision",
        "flags": ["zimage", "zimage_refiner"],
    },
}
# Quantization format options
# NOTE: MXFP8/NVFP4 temporarily disabled - see TODO.md
# Maps UI label -> ConversionConfig settings. A None scaling_mode falls back to
# "tensor" at call time; block_size is only meaningful for block-scaled modes.
QUANT_FORMATS = {
    "FP8 Tensorwise": {"format": "fp8", "scaling_mode": "tensor", "block_size": None},
    "FP8 Block (128)": {"format": "fp8", "scaling_mode": "block", "block_size": 128},
    "FP8 Block (64)": {"format": "fp8", "scaling_mode": "block", "block_size": 64},
    "INT8 Block (128)": {"format": "int8", "scaling_mode": None, "block_size": 128},
}
| def download_model_from_hub(source_repo: str, file_path: str) -> tuple[str | None, str]: | |
| """ | |
| Download a model file from HuggingFace Hub. | |
| Args: | |
| source_repo: Repository ID (username/repo_name) | |
| file_path: Path to file within the repository | |
| Returns: | |
| Tuple of (local_path, status_message) | |
| """ | |
| if not source_repo or not file_path: | |
| return None, "❌ Please provide both source repository and file path" | |
| try: | |
| local_path = hf_hub_download( | |
| repo_id=source_repo, | |
| filename=file_path, | |
| token=HF_TOKEN, | |
| local_dir=tempfile.gettempdir(), | |
| ) | |
| return local_path, f"✅ Downloaded: `{file_path}` from `{source_repo}`" | |
| except Exception as e: | |
| return None, f"❌ Download failed: {str(e)}" | |
def upload_model_as_pr(
    local_path: str,
    target_repo: str,
    target_path: str,
    pr_title: str,
) -> str:
    """Open a Pull Request on the Hub that adds a single file.

    Args:
        local_path: Path to the local file to upload.
        target_repo: Target repository ID (username/repo_name).
        target_path: Destination path inside the target repository.
        pr_title: Commit message used as the PR title.

    Returns:
        Markdown status message (PR link on success, error text on failure).
    """
    # Nothing to upload without an existing local file.
    if not (local_path and os.path.exists(local_path)):
        return "❌ No file to upload"
    # Destination repo and path are both required.
    if not (target_repo and target_path):
        return "❌ Please provide target repository and path"
    try:
        api = HfApi(token=HF_TOKEN)
        operation = CommitOperationAdd(
            path_in_repo=target_path,
            path_or_fileobj=local_path,
        )
        # create_pr=True opens a PR instead of committing to the default branch.
        commit_info = api.create_commit(
            repo_id=target_repo,
            operations=[operation],
            commit_message=pr_title or "Add quantized model",
            create_pr=True,
        )
        link = commit_info.pr_url
        return f"✅ Pull Request created: [{link}]({link})"
    except Exception as e:
        return f"❌ Upload failed: {str(e)}"
def completion_signal():
    """Brief GPU allocation at completion to satisfy ZeroGPU requirements."""
    # NOTE(review): this function is not decorated with @spaces.GPU, so as
    # written it does not actually allocate a GPU — confirm whether ZeroGPU
    # only needs the `spaces` import to be present, or whether the decorator
    # is required here for the Space to pass ZeroGPU detection.
    return True
def run_quantization(config):
    """Execute the conversion described by *config* (CPU-bound; no GPU timeout applies)."""
    result = convert(config)
    return result
def _preset_filter_flags(model_preset: str) -> dict:
    """Return the filter-flag dict implied by a MODEL_FILTER_PRESETS entry.

    "None", an empty value, or an unknown preset yields an empty dict
    (no filtering). A preset may enable several flags at once.
    """
    flags = {}
    if model_preset and model_preset != "None":
        preset_config = MODEL_FILTER_PRESETS.get(model_preset)
        if preset_config:
            for flag in preset_config.get("flags", []):
                flags[flag] = True
    return flags


def _derive_output_name(input_path: str, quant_format: str) -> str:
    """Build the output filename: "<input base>_<normalized format>.safetensors"."""
    base_name = os.path.splitext(os.path.basename(input_path))[0]
    # e.g. "FP8 Block (128)" -> "fp8_block_128"
    format_suffix = quant_format.lower().replace(" ", "_").replace("(", "").replace(")", "")
    return f"{base_name}_{format_suffix}.safetensors"


def quantize_model(
    source_repo: str,
    file_path: str,
    quant_format: str,
    model_preset: str,
    exclude_layers_regex: str,
    full_precision_matmul: bool,
    target_repo: str,
    target_path: str,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Download, quantize, and optionally upload a model.

    Args:
        source_repo: Source HuggingFace repository (username/repo_name)
        file_path: Path to model file in source repo
        quant_format: Selected quantization format (key of QUANT_FORMATS)
        model_preset: Model preset filter (or "None")
        exclude_layers_regex: Regex pattern for layers to exclude
        full_precision_matmul: Enable full precision matrix multiplication
        target_repo: Target repository for upload (optional)
        target_path: Target path in repository (optional)

    Returns:
        Tuple of (output_file_path or None, markdown status message)
    """
    status_log = []
    # Step 1: Download model from Hub
    status_log.append("📥 **Downloading model...**")
    input_path, download_status = download_model_from_hub(source_repo, file_path)
    status_log.append(download_status)
    if input_path is None:
        return None, "\n\n".join(status_log)
    # Step 2: Resolve format settings from the UI label
    format_config = QUANT_FORMATS.get(quant_format)
    if not format_config:
        status_log.append(f"❌ Unknown format: {quant_format}")
        return None, "\n\n".join(status_log)
    filter_flags = _preset_filter_flags(model_preset)
    output_name = _derive_output_name(input_path, quant_format)
    output_path = os.path.join(tempfile.gettempdir(), output_name)
    # Step 3: Build conversion config
    status_log.append(f"⚙️ **Quantizing with {quant_format}...**")
    # FIX: pass the *stripped* regex to the converter. Previously the raw
    # textbox value (with surrounding whitespace) was forwarded even though
    # the truthiness check used .strip().
    exclude_pattern = exclude_layers_regex.strip() if exclude_layers_regex else ""
    config = ConversionConfig(
        input_path=input_path,
        output_path=output_path,
        quant_format=format_config["format"],
        comfy_quant=True,
        save_quant_metadata=True,
        simple=True,
        verbose="VERBOSE",
        # INT8 entries carry scaling_mode=None; fall back to "tensor".
        scaling_mode=format_config.get("scaling_mode") or "tensor",
        block_size=format_config.get("block_size"),
        filter_flags=filter_flags,
        exclude_layers=exclude_pattern or None,
        full_precision_matrix_mult=full_precision_matmul,
        force_cpu=True,  # Bypass CUDA checks on ZeroGPU
    )
    try:
        result = run_quantization(config)
        if not result.success:
            status_log.append(f"❌ Quantization failed: {result.error}")
            return None, "\n\n".join(status_log)
        status_log.append(f"✅ Quantization complete: `{os.path.basename(result.output_path)}`")
        # Step 4: Upload to target repo if specified
        if target_repo and target_repo.strip():
            status_log.append("📤 **Uploading as Pull Request...**")
            # Use provided target path or fall back to the generated filename
            upload_path = target_path.strip() if target_path and target_path.strip() else output_name
            pr_title = f"Add {quant_format} quantized model: {output_name}"
            upload_status = upload_model_as_pr(
                result.output_path,
                target_repo.strip(),
                upload_path,
                pr_title,
            )
            status_log.append(upload_status)
        # Brief GPU allocation at completion to satisfy ZeroGPU
        completion_signal()
        return result.output_path, "\n\n".join(status_log)
    except Exception as e:
        # Surface any unexpected converter/upload error in the status log.
        status_log.append(f"❌ Error: {str(e)}")
        return None, "\n\n".join(status_log)
# Build Gradio interface
# FIX: the formats list previously advertised MXFP8/NVFP4 as available even
# though QUANT_FORMATS does not offer them (disabled per its NOTE); the bullet
# is now marked as temporarily disabled so the UI text matches the dropdown.
DESCRIPTION = """
# Model Quantization for ComfyUI
Quantize safetensors models to FP8/INT8 formats for efficient inference in ComfyUI.
## Workflow
1. Enter source repository and file path to download model
2. Select quantization format and model preset
3. Optionally configure target repo to upload as Pull Request
## Quantization Formats
- **FP8 Tensorwise**: Standard FP8 with per-tensor scaling (most compatible)
- **FP8 Block**: FP8 with per-block scaling (better accuracy)
- **INT8 Block**: INT8 with per-block scaling (Triton-based)
- **MXFP8/NVFP4**: Next-gen formats (requires Blackwell GPU) — temporarily disabled
"""
# Assemble the Gradio UI: two columns (inputs left, outputs right) plus the
# event wiring for preset-description updates and the quantize button.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Left column: source selection, quantization settings, upload target.
        with gr.Column(scale=2):
            gr.Markdown("### 📥 Source Model")
            source_repo = gr.Textbox(
                label="Source Repository",
                placeholder="username/model-repo",
                info="HuggingFace repository ID",
            )
            file_path = gr.Textbox(
                label="File Path in Repository",
                placeholder="model.safetensors or path/to/model.safetensors",
                info="Path to the safetensors file within the repo",
            )
            gr.Markdown("### ⚙️ Quantization Settings")
            quant_format = gr.Dropdown(
                label="Quantization Format",
                choices=list(QUANT_FORMATS.keys()),
                value="FP8 Tensorwise",
            )
            model_preset = gr.Dropdown(
                label="Model Preset",
                # "None" sentinel disables preset filtering in quantize_model.
                choices=["None"] + list(MODEL_FILTER_PRESETS.keys()),
                value="None",
                info="Preset layer exclusions for specific architectures",
            )
            with gr.Accordion("Advanced Options", open=False):
                exclude_layers = gr.Textbox(
                    label="Exclude Layers (Regex)",
                    placeholder="e.g., img_in|txt_in|final_layer",
                    info="Regex pattern for layers to keep in original precision",
                )
                full_precision_matmul = gr.Checkbox(
                    label="Full Precision Matrix Multiply",
                    value=False,
                    info="Use FP32 matmul (for storage-only quantization)",
                )
            gr.Markdown("### 📤 Upload Target (Optional)")
            target_repo = gr.Textbox(
                label="Target Repository",
                placeholder="username/target-repo",
                info="Leave empty to skip upload",
            )
            target_path = gr.Textbox(
                label="Target Path in Repository",
                placeholder="quantized/model_fp8.safetensors",
                info="Path where file will be uploaded (optional)",
            )
            quantize_btn = gr.Button("🚀 Quantize Model", variant="primary", size="lg")
        # Right column: downloadable result and markdown status log.
        with gr.Column(scale=1):
            output_file = gr.File(
                label="Download Quantized Model",
                interactive=False,
            )
            status = gr.Markdown(label="Status")
            # Show preset description when selected
            preset_info = gr.Markdown("")
    # Update preset info when selection changes
    def update_preset_info(preset):
        """Return a markdown line describing the selected preset ("" for None)."""
        if preset and preset != "None":
            preset_config = MODEL_FILTER_PRESETS.get(preset, {})
            desc = preset_config.get("description", "")
            return f"**{preset}**: {desc}"
        return ""
    model_preset.change(
        fn=update_preset_info,
        inputs=[model_preset],
        outputs=[preset_info],
    )
    # Run quantization
    # Input order must match quantize_model's positional parameters.
    quantize_btn.click(
        fn=quantize_model,
        inputs=[
            source_repo,
            file_path,
            quant_format,
            model_preset,
            exclude_layers,
            full_precision_matmul,
            target_repo,
            target_path,
        ],
        outputs=[output_file, status],
    )
if __name__ == "__main__":
    # FIX: Blocks.launch() has no `theme` parameter — passing one raises
    # TypeError at startup. Themes are set on the gr.Blocks(...) constructor
    # instead (e.g. gr.Blocks(theme=gr.themes.Soft())).
    demo.launch()