---
title: Model Optimization Lab
emoji: 😻
colorFrom: pink
colorTo: gray
sdk: gradio
sdk_version: 6.0.0
app_file: app.py
pinned: false
---
# Model Optimization Lab
Interactive Gradio playground for comparing pruning and quantization on ImageNet classification, ADE20K segmentation, and COCO detection models (TorchVision + YOLO12). Upload any image and observe how latency, confidence, model size, and segmentation/detection quality change when applying different compression recipes. Pretrained weights are loaded by default; set `MODEL_OPT_PRETRAINED=0` if you want random initialization for experimentation.
## Features
- Classification Tasks: Baseline FP32 inference using cached backbones (ResNet-50, MobileNetV3, EfficientNet-B0, ConvNeXt-Tiny, ViT-B/16, RegNetY-016, EfficientNet-Lite0).
- Segmentation Tasks: Pretrained ADE20K models (SegFormer B0/B4, DPT Large, UPerNet ConvNeXt-Tiny) with 150-class semantic segmentation.
- Detection Tasks: COCO-pretrained detectors (TorchVision Faster R-CNN/SSDlite) plus Ultralytics YOLO12 n/s/m/l/x.
- Pruning tabs: Structured/unstructured pruning with configurable sparsity and comprehensive size/latency comparison across tasks.
- Quantization tabs: Dynamic, weight-only INT8, and FP16 passes with CPU-safe fallbacks for unsupported kernels, available for all tasks.
- Visual Comparisons:
  - Classification: Automated metric tables and Top-5 bar charts to visualize confidence shifts.
  - Segmentation: Image sliders for overlay/mask comparisons, class distribution tables, and mask agreement metrics.
  - Detection: Overlay sliders for pruned/quantized boxes and detection tables for quick inspection.
- Export Options: TorchScript, ONNX, JSON reports, and state dictionaries for all optimization variants.
- Lightweight CLI mode for quick experiments without launching the UI.
## Requirements
- Python 3.9+
- PyTorch with CPU support (GPU optional but recommended for FP16 experiments).
- The packages listed in `requirements.txt`:
  - `torch`, `torchvision`: Core PyTorch framework
  - `timm`: Classification model architectures
  - `segmentation-models-pytorch`: Segmentation model architectures
  - `albumentations`: Image preprocessing for segmentation models
  - `gradio`: Web UI framework
  - `pandas`, `matplotlib`, `numpy`, `pillow`: Data processing and visualization
## Quick Start
- Clone the repository:
  ```bash
  git clone https://github.com/shriarul5273/model-optimization-lab.git
  cd model-optimization-lab
  ```
- Create and activate a virtual environment (optional but recommended).
- Install dependencies:
  ```bash
  pip install -r requirements.txt
  ```
- Launch the Gradio app:
  ```bash
  python app.py
  ```
- Open the local Gradio URL (printed in the terminal) in your browser.
## Using the App
- Upload an image or pick one of the provided examples (ImageNet samples for classification, ADE20K validation images for segmentation; detection works with any RGB image).
- Choose a model from the Base Model dropdown:
  - Classification: ResNet-50, MobileNetV3-Large, EfficientNet-B0, ConvNeXt-Tiny, ViT-B/16, RegNetY-016, EfficientNet-Lite0
  - Segmentation: SegFormer B0/B4 (ADE20K 512x512), DPT Large (ADE20K), UPerNet ConvNeXt-Tiny (ADE20K)
  - Detection: Faster R-CNN ResNet50 FPN (COCO), SSDlite320 MobileNetV3 (COCO), YOLO12 n/s/m/l/x (COCO via Ultralytics)
- Pick a Hardware Preset or keep `custom` (a sketch of such a preset table follows this list):
  - Edge CPU: CPU, channels-last off, dynamic quantization, 30% pruning.
  - Datacenter GPU: CUDA, channels-last on, `torch.compile`, FP16 quantization, 20% pruning.
  - Apple MPS: MPS, FP16 quantization, 20% pruning.
- Select a tab (Pruning/Quantization for Classification, Detection, or Segmentation), configure options, then click Run.
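
For reference, the presets map to option combinations roughly like the table below. This is an illustrative sketch only; the dictionary name `HARDWARE_PRESETS` and its keys are hypothetical, not taken from `app.py`.

```python
# Hypothetical preset table mirroring the UI choices above; names and keys
# are illustrative, not the app's actual data structure.
HARDWARE_PRESETS = {
    "edge_cpu": {
        "device": "cpu", "channels_last": False,
        "quantization": "dynamic", "prune_amount": 0.30,
    },
    "datacenter_gpu": {
        "device": "cuda", "channels_last": True, "torch_compile": True,
        "quantization": "fp16", "prune_amount": 0.20,
    },
    "apple_mps": {
        "device": "mps", "quantization": "fp16", "prune_amount": 0.20,
    },
}
```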
## Pruning tab options (Classification & Segmentation)
- Pruning Method: `structured` (LN-structured) or `unstructured` (L1). Applied to Conv2d weights before export (see the sketch after this section).
- Pruning Amount: 0.1–0.9 sparsity. Higher numbers zero more weights; latency impact depends on kernel support.
- Device: `auto` picks CUDA → MPS → CPU. Channels-last is only honored on CUDA.
- Channels-last input (CUDA): Converts tensors to channels-last for better CUDA kernel throughput.
- Mixed precision (AMP): Enables CUDA autocast for FP16/FP32 mixes.
- Torch compile (PyTorch 2): Wraps the pruned model in `torch.compile` when available.
- Exports:
  - Classification: `pruned_model.ts`, `pruned_model.onnx`, `pruned_report.json`, `pruned_state_dict.pth`
  - Segmentation: `pruned_seg_model.ts`, `pruned_seg_model.onnx`, `pruned_seg_report.json`, `pruned_seg_state_dict.pth`
- Outputs:
  - Classification: Comparison metrics, Top-5 bar chart, per-layer sparsity table, download list
  - Segmentation: Comparison metrics, class distribution table, overlay/mask sliders, per-layer sparsity table, download list
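
For context, the two pruning methods correspond to PyTorch's built-in utilities. The sketch below is a minimal illustration using `torch.nn.utils.prune` on a TorchVision backbone; it is not the app's exact code, and the `apply_pruning` helper is hypothetical.

```python
import torch
import torchvision
from torch.nn.utils import prune

# Illustrative sketch, not the app's implementation: apply the tab's two
# pruning modes to every Conv2d layer of a TorchVision backbone.
model = torchvision.models.resnet50(weights=None).eval()

def apply_pruning(model: torch.nn.Module, method: str = "unstructured",
                  amount: float = 0.4) -> torch.nn.Module:
    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d):
            if method == "structured":
                # LN-structured: zero whole output channels by L2 norm.
                prune.ln_structured(module, name="weight", amount=amount, n=2, dim=0)
            else:
                # L1 unstructured: zero the smallest weights individually.
                prune.l1_unstructured(module, name="weight", amount=amount)
            prune.remove(module, "weight")  # bake the mask in before export
    return model

apply_pruning(model, method="structured", amount=0.3)
```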
## Detection tab options (Pruning & Quantization)
- Models: TorchVision Faster R-CNN / SSDlite, plus Ultralytics YOLO12 n/s/m/l/x (auto-downloads checkpoints if missing).
- Score Threshold: Filters low-confidence boxes before metrics/overlays (see the sketch after this list).
- Pruning: Structured recommended for detection heads; unstructured yields higher sparsity but fewer real speedups.
- Quantization: Dynamic/weight-only INT8 forces CPU for kernel support; FP16 targets CUDA/MPS. AMP + channels-last help on GPU.
- Exports: State dicts are always saved. TorchScript/ONNX exports remain enabled for TorchVision detectors; YOLO12 TorchScript/ONNX exports are skipped, but the state dict is still written.
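
As an illustration of the score threshold, here is a minimal sketch of filtering TorchVision detector outputs by confidence, assuming a pretrained Faster R-CNN; the threshold value and input tensor are placeholders, not the app's code.

```python
import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn

# Minimal sketch: filter low-confidence boxes before drawing overlays
# or computing metrics, as the tab's Score Threshold slider does.
model = fasterrcnn_resnet50_fpn(weights="DEFAULT").eval()
image = torch.rand(3, 480, 640)  # placeholder RGB tensor in [0, 1]

with torch.inference_mode():
    output = model([image])[0]

threshold = 0.5  # placeholder; the UI exposes this as a slider
keep = output["scores"] >= threshold
boxes, labels = output["boxes"][keep], output["labels"][keep]
```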
## Quantization tab options (Classification & Segmentation)
- Quantization Type: `dynamic` / `weight_only` (INT8 linear layers on CPU) or `fp16` (casts the model to half precision); see the sketch after this section.
- Device: `auto` picks CUDA → MPS → CPU; dynamic/weight-only runs force CPU execution for kernel support.
- Channels-last input (CUDA): Same as pruning; ignored on CPU.
- Mixed precision (AMP): Applies CUDA autocast to the quantized forward pass.
- Torch compile (PyTorch 2): Compiles the quantized model when available.
- Exports:
  - Classification: `quantized_model.ts`, `quantized_model.onnx`, `quant_report.json`, `quantized_state_dict.pth`
  - Segmentation: `quant_seg_model.ts`, `quant_seg_model.onnx`, `quant_seg_report.json`, `quant_seg_state_dict.pth`
- Outputs:
  - Classification: Comparison metrics, Top-5 bar chart, download list
  - Segmentation: Comparison metrics, class distribution table, overlay/mask sliders, download list
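
For orientation, the `dynamic` and `fp16` passes correspond roughly to the standard PyTorch calls below. This is a hedged sketch under those assumptions, not the app's implementation.

```python
import torch
import torchvision

# Illustrative sketch, not the app's code.
model = torchvision.models.resnet50(weights=None).eval()

# dynamic / weight_only: INT8 kernels for nn.Linear layers; CPU only.
int8_model = torch.ao.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

# fp16: cast the whole model to half precision (intended for CUDA/MPS).
if torch.cuda.is_available():
    fp16_model = model.half().to("cuda")
```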
## What gets exported
- Artifacts are written to `exports/`. JSON reports include the chosen options, metrics, and Top-5 results for both the baseline and optimized variants.
- TorchScript/ONNX exports run best on CPU inputs; failures are logged to the console and skipped.
- State dicts are always saved for reproducibility; disable or prune them manually if you are embedding this module elsewhere.
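
The export flow can be approximated as below: a minimal sketch assuming a CPU model and example input, with the `export_artifacts` helper name being hypothetical rather than the app's API.

```python
import torch

def export_artifacts(model, example, out_dir="exports"):
    """Hypothetical helper mirroring the export behavior described above."""
    model = model.cpu().eval()
    example = example.cpu()
    try:
        torch.jit.trace(model, example).save(f"{out_dir}/pruned_model.ts")
    except Exception as err:  # failures are logged and skipped
        print(f"TorchScript export skipped: {err}")
    try:
        torch.onnx.export(model, (example,), f"{out_dir}/pruned_model.onnx")
    except Exception as err:
        print(f"ONNX export skipped: {err}")
    # State dicts are always saved for reproducibility.
    torch.save(model.state_dict(), f"{out_dir}/pruned_state_dict.pth")
```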
## Interpreting the Outputs
- Top-1 Prediction (Classification): Labels come from ImageNet synsets, so some entries include multiple comma-separated synonyms (e.g., `chambered nautilus, pearly nautilus`).
- Mask Agreement (Segmentation): Percentage of pixels where the original and optimized models predict the same class. 100% means identical masks; lower values indicate divergence (see the sketch after this list).
- Class Distribution (Segmentation): Shows the top 25 most prevalent classes by pixel coverage, with percentages and counts for both models.
- Image Sliders (Segmentation): Drag the slider to compare original vs. optimized overlays or raw masks side-by-side.
- Latency (ms): The reported inference latency for each pass. Large numbers for quantized runs may reflect preprocessing overhead rather than the model's actual execution speed; see Performance Notes.
- Model Size (MB): Serialized state dictionary size after saving to disk.
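
Mask agreement is a straightforward pixel-wise comparison; a minimal sketch, with the function name being illustrative:

```python
import numpy as np

def mask_agreement(original: np.ndarray, optimized: np.ndarray) -> float:
    """Percentage of pixels where both class-index masks agree."""
    assert original.shape == optimized.shape
    return 100.0 * float((original == optimized).mean())
```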
## Performance Notes
- The current quantization pipeline rebuilds the optimized model on each request, so the reported latency includes setup time. Reusing pre-quantized instances will yield more realistic numbers.
- Dynamic and weight-only quantization affect only linear layers; ResNet-50 is dominated by convolution blocks that remain FP32, so CPU speedups are modest. Unsupported static INT8 kernels automatically fall back to dynamic quantization.
- PyTorch's default quantization backend may fall back to `qnnpack` on CPU. On x86 systems, set `torch.backends.quantized.engine = "fbgemm"` before quantizing for best results (see the snippet after this list).
- FP16 inference is beneficial on GPUs. On CPU, PyTorch often casts half tensors back to float32, introducing overhead.
- Detection-specific: Dynamic/weight-only runs force CPU for kernel support; YOLO12 checkpoints auto-download but TorchScript/ONNX exports are disabled (state dicts still save).
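
Selecting the backend is a one-liner; a minimal sketch assuming a CPU dynamic-quantization run on a toy model:

```python
import torch

# Prefer the x86 fbgemm backend, when available, before quantizing on CPU.
if "fbgemm" in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = "fbgemm"

model = torch.nn.Sequential(torch.nn.Linear(512, 256), torch.nn.ReLU())
int8_model = torch.ao.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
```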
## Extending the Lab
- Classification: Swap in different architectures by changing the `timm.create_model` call in `app.py`.
- Segmentation: Add new models from the smp-hub collection by adding entries to `SEGMENTATION_MODEL_CONFIGS`.
- Add calibration data and static INT8 quantization to include convolution layers.
- Cache optimized models to avoid recomputation across requests.
- Integrate evaluation datasets to quantify accuracy drop (classification: top-1/top-5, segmentation: mIoU, pixel accuracy); a top-k accuracy sketch follows this list.
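
As a starting point for the accuracy-drop evaluation, here is a minimal top-k accuracy sketch; the helper name is illustrative:

```python
import torch

def topk_accuracy(logits: torch.Tensor, targets: torch.Tensor, k: int = 5) -> float:
    """Fraction of samples whose true label appears in the top-k predictions."""
    topk = logits.topk(k, dim=1).indices           # (N, k) predicted class ids
    hits = (topk == targets.unsqueeze(1)).any(dim=1)
    return hits.float().mean().item()
```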
## CLI Mode
- Run without the UI:
  ```bash
  python app.py --cli --image path/to/img.jpg --mode prune --model resnet50 --device auto
  ```
- Modes: `--mode prune` (structured pruning @ 0.4 sparsity) or `--mode quant` (dynamic quantization). Both emit the metrics table and export artifacts list.
- Devices: `auto` chooses CUDA → MPS → CPU based on availability; `cpu`/`cuda`/`mps` force a device. Dynamic/weight-only quantization forces CPU for kernel support even if a GPU is requested.
- Models: Any entry from `MODEL_OPTIONS` in `app.py`.
## Troubleshooting
- Slow downloads: The first run downloads pretrained weights (~100 MB). Subsequent runs use cached files.
- CUDA errors: Ensure the correct CUDA-enabled PyTorch build is installed if you intend to run on GPU.
- Quantized model larger than expected: The state dictionary includes dequantized tensors for some paths (e.g., dynamic quantization). Consider TorchScript or ONNX export for compact deployment artifacts.
## License
This project inherits the default license of the repository. Replace or update this section if you add a specific license.