---
title: Model Optimization Lab
emoji: 😻
colorFrom: pink
colorTo: gray
sdk: gradio
sdk_version: 6.0.0
app_file: app.py
pinned: false
---

# Model Optimization Lab

Interactive Gradio playground for comparing pruning and quantization on ImageNet classification, ADE20K segmentation, and COCO detection models (TorchVision + YOLO12). Upload any image and observe how latency, confidence, model size, and segmentation/detection quality change when applying different compression recipes. Pretrained weights are loaded by default; set `MODEL_OPT_PRETRAINED=0` if you want random initialization for experimentation.

## Features

- **Classification Tasks**: Baseline FP32 inference using cached backbones (ResNet-50, MobileNetV3, EfficientNet-B0, ConvNeXt-Tiny, ViT-B/16, RegNetY-016, EfficientNet-Lite0).
- **Segmentation Tasks**: Pretrained ADE20K models (SegFormer B0/B4, DPT Large, UPerNet ConvNeXt-Tiny) with 150-class semantic segmentation.
- **Detection Tasks**: COCO-pretrained detectors (TorchVision Faster R-CNN/SSDlite) plus Ultralytics YOLO12 n/s/m/l/x.
- **Pruning Tabs**: Structured/unstructured pruning with configurable sparsity and comprehensive size/latency comparison across tasks.
- **Quantization Tabs**: Dynamic, weight-only INT8, and FP16 passes with CPU-safe fallbacks for unsupported kernels, available for all tasks.
- **Visual Comparisons**:
  - Classification: Automated metric tables and Top-5 bar charts to visualize confidence shifts.
  - Segmentation: Image sliders for overlay/mask comparisons, class distribution tables, and mask agreement metrics.
  - Detection: Overlay sliders for pruned/quantized boxes and detection tables for quick inspection.
- **Export Options**: TorchScript, ONNX, JSON reports, and state dictionaries for all optimization variants.
- **CLI Mode**: Lightweight command-line interface for quick experiments without launching the UI.

## Requirements

- Python 3.9+
- PyTorch with CPU support (GPU optional but recommended for FP16 experiments).
- The packages listed in `requirements.txt`:
  - `torch`, `torchvision` - Core PyTorch framework
  - `timm` - Classification model architectures
  - `segmentation-models-pytorch` - Segmentation model architectures
  - `albumentations` - Image preprocessing for segmentation models
  - `gradio` - Web UI framework
  - `pandas`, `matplotlib`, `numpy`, `pillow` - Data processing and visualization

## Quick Start

1. Clone the repository:
   ```bash
   git clone https://github.com/shriarul5273/model-optimization-lab.git
   cd model-optimization-lab
   ```
2. Create and activate a virtual environment (optional but recommended).
3. Install dependencies:
   ```bash
   pip install -r requirements.txt
   ```
4. Launch the Gradio app:
   ```bash
   python app.py
   ```
5. Open the local Gradio URL (printed in the terminal) in your browser.

## Using the App

1. **Upload an image** or pick one of the provided examples (ImageNet samples for classification, ADE20K validation images for segmentation; detection works with any RGB image).
2. Choose a model from the **Base Model** dropdown:
   - **Classification**: ResNet-50, MobileNetV3-Large, EfficientNet-B0, ConvNeXt-Tiny, ViT-B/16, RegNetY-016, EfficientNet-Lite0
   - **Segmentation**: SegFormer B0/B4 (ADE20K 512x512), DPT Large (ADE20K), UPerNet ConvNeXt-Tiny (ADE20K)
   - **Detection**: Faster R-CNN ResNet50 FPN (COCO), SSDlite320 MobileNetV3 (COCO), YOLO12 n/s/m/l/x (COCO via Ultralytics)
3. Pick a **Hardware Preset** or keep `custom`:
   - Edge CPU: CPU, channels-last off, dynamic quantization, 30% pruning.
   - Datacenter GPU: CUDA, channels-last on, `torch.compile`, FP16 quantization, 20% pruning.
   - Apple MPS: MPS, FP16 quantization, 20% pruning.
4. Select a tab (Pruning/Quantization for Classification, Detection, or Segmentation), configure options, then click **Run**.

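The `auto` device resolution used throughout the app (CUDA → MPS → CPU) can be sketched roughly as follows; the helper name is illustrative, not the app's actual function:

```python
import torch

def pick_device() -> torch.device:
    """Resolve `auto`: prefer CUDA, then Apple MPS, then fall back to CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()
```
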
### Pruning tab options (Classification & Segmentation)

- `Pruning Method`: `structured` (LN-structured) or `unstructured` (L1). Applied to Conv2d weights before export.
- `Pruning Amount`: 0.1–0.9 sparsity. Higher values zero more weights; the latency impact depends on kernel support.
- `Device`: `auto` picks CUDA → MPS → CPU. Channels-last is only honored on CUDA.
- `Channels-last input (CUDA)`: Converts tensors to channels-last for better CUDA kernel throughput.
- `Mixed precision (AMP)`: Enables CUDA autocast for FP16/FP32 mixes.
- `Torch compile (PyTorch 2)`: Wraps the pruned model in `torch.compile` when available.
- **Exports**:
  - Classification: `pruned_model.ts`, `pruned_model.onnx`, `pruned_report.json`, `pruned_state_dict.pth`
  - Segmentation: `pruned_seg_model.ts`, `pruned_seg_model.onnx`, `pruned_seg_report.json`, `pruned_seg_state_dict.pth`
- **Outputs**:
  - Classification: Comparison metrics, Top-5 bar chart, per-layer sparsity table, download list
  - Segmentation: Comparison metrics, class distribution table, overlay/mask sliders, per-layer sparsity table, download list

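The two pruning methods above can be sketched with `torch.nn.utils.prune`; this is a minimal illustration of the technique, not the app's exact code path, and the function name is hypothetical:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def apply_pruning(model: nn.Module, method: str = "unstructured",
                  amount: float = 0.4) -> nn.Module:
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            if method == "structured":
                # LN-structured: removes whole output channels by L2 norm.
                prune.ln_structured(module, name="weight", amount=amount,
                                    n=2, dim=0)
            else:
                # L1 unstructured: zeros individual weights by magnitude.
                prune.l1_unstructured(module, name="weight", amount=amount)
            # Bake the mask into the weight tensor before export.
            prune.remove(module, "weight")
    return model

model = apply_pruning(nn.Sequential(nn.Conv2d(3, 8, 3)), "unstructured", 0.5)
sparsity = (model[0].weight == 0).float().mean().item()  # ~0.5
```

Note that zeroed weights only translate into real speedups when the runtime has sparse-aware kernels; otherwise only the structured variant shrinks actual compute.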
### Detection tab options (Pruning & Quantization)

- `Models`: TorchVision Faster R-CNN / SSDlite, plus Ultralytics YOLO12 n/s/m/l/x (auto-downloads checkpoints if missing).
- `Score Threshold`: Filters low-confidence boxes before metrics/overlays.
- `Pruning`: Structured pruning is recommended for detection heads; unstructured yields higher sparsity but fewer real speedups.
- `Quantization`: Dynamic/weight-only INT8 forces CPU for kernel support; FP16 targets CUDA/MPS. AMP and channels-last help on GPU.
- `Exports`: State dicts are always saved. TorchScript/ONNX exports remain enabled for TorchVision detectors; for YOLO12, TorchScript/ONNX exports are skipped but the state dict is still written.

### Quantization tab options (Classification & Segmentation)

- `Quantization Type`: `dynamic`/`weight_only` (INT8 linear layers on CPU), or `fp16` (casts the model to half precision).
- `Device`: `auto` picks CUDA → MPS → CPU; dynamic/weight-only runs force CPU execution for kernel support.
- `Channels-last input (CUDA)`: Same as pruning; ignored on CPU.
- `Mixed precision (AMP)`: Applies CUDA autocast to the quantized forward pass.
- `Torch compile (PyTorch 2)`: Compiles the quantized model when available.
- **Exports**:
  - Classification: `quantized_model.ts`, `quantized_model.onnx`, `quant_report.json`, `quantized_state_dict.pth`
  - Segmentation: `quant_seg_model.ts`, `quant_seg_model.onnx`, `quant_seg_report.json`, `quant_seg_state_dict.pth`
- **Outputs**:
  - Classification: Comparison metrics, Top-5 bar chart, download list
  - Segmentation: Comparison metrics, class distribution table, overlay/mask sliders, download list

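The `dynamic` pass described above can be sketched in a few lines: only `nn.Linear` layers are converted to INT8, and execution stays on CPU. This is an illustration of the standard PyTorch API, not the app's exact implementation:

```python
import torch
import torch.nn as nn

# A toy model; in the app this would be the selected backbone.
model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10)).eval()

# Dynamic INT8 quantization: weights are quantized ahead of time,
# activations are quantized on the fly at inference.
quantized = torch.ao.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

x = torch.randn(1, 128)
with torch.inference_mode():
    out = quantized(x)  # runs on CPU with INT8 linear kernels
```
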
### What gets exported

- Artifacts are written to `exports/`. JSON reports include the chosen options, metrics, and Top-5 results for both the baseline and optimized variants.
- TorchScript/ONNX exports run best on CPU inputs; failures are logged to the console and skipped.
- State dicts are always saved for reproducibility; disable or prune them manually if you are embedding this module elsewhere.

### Output Interpretation Tips

- **Top-1 Prediction (Classification)**: Labels come from ImageNet synsets, so some entries include multiple comma-separated synonyms (e.g., `chambered nautilus, pearly nautilus`).
- **Mask Agreement (Segmentation)**: The percentage of pixels where the original and optimized models predict the same class. 100% means identical masks; lower values indicate divergence.
- **Class Distribution (Segmentation)**: Shows the 25 most prevalent classes by pixel coverage, with percentages and counts for both models.
- **Image Sliders (Segmentation)**: Drag the slider to compare original vs. optimized overlays or raw masks side by side.
- **Latency (ms)**: The reported inference latency for each pass. Large numbers for quantized runs may reflect setup and preprocessing overhead rather than the model itself running slower; see [Performance Notes](#performance-notes).
- **Model Size (MB)**: Serialized state dictionary size after saving to disk.

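The mask agreement metric above is a simple pixel-wise comparison; a minimal sketch (the helper name is illustrative, not the app's actual function):

```python
import numpy as np

def mask_agreement(mask_a: np.ndarray, mask_b: np.ndarray) -> float:
    """Percentage of pixels whose predicted class is identical."""
    return float((mask_a == mask_b).mean() * 100.0)

baseline = np.array([[0, 1], [2, 2]])
optimized = np.array([[0, 1], [2, 3]])
score = mask_agreement(baseline, optimized)  # 75.0: 3 of 4 pixels agree
```
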
## Performance Notes

- The current quantization pipeline rebuilds the optimized model on each request, so the reported latency includes setup time. Reusing pre-quantized instances will yield more realistic numbers.
- Dynamic and weight-only quantization only affect linear layers; ResNet-50 is dominated by convolution blocks that remain FP32, so speedups are modest on CPU. Unsupported static INT8 kernels automatically fall back to dynamic quantization.
- PyTorch's default quantization backend may fall back to `qnnpack` on CPU. On x86 systems, set `torch.backends.quantized.engine = "fbgemm"` before quantizing for best results.
- FP16 inference is beneficial on GPUs. On CPU, PyTorch often casts half-precision tensors back to float32, introducing overhead.
- Detection-specific: dynamic/weight-only runs force CPU for kernel support; YOLO12 checkpoints auto-download, but TorchScript/ONNX exports are disabled (state dicts still save).

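The backend switch mentioned above can be applied defensively, since not every build ships every engine:

```python
import torch

# Prefer the x86-optimized fbgemm INT8 backend when the current
# PyTorch build provides it; otherwise keep the default engine.
if "fbgemm" in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = "fbgemm"
```
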
## Extending the Lab

- **Classification**: Swap in different architectures by changing the `timm.create_model` call in `app.py`.
- **Segmentation**: Add new models from the [smp-hub](https://huggingface.co/smp-hub) collection by adding entries to `SEGMENTATION_MODEL_CONFIGS`.
- Add calibration data and static INT8 quantization to cover convolution layers as well.
- Cache optimized models to avoid recomputation across requests.
- Integrate evaluation datasets to quantify accuracy drop (classification: top-1/top-5; segmentation: mIoU, pixel accuracy).

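The caching suggestion above could be as simple as memoizing the build step on its options; a minimal sketch with a stand-in builder (names are hypothetical):

```python
from functools import lru_cache

build_count = 0

@lru_cache(maxsize=8)
def get_optimized(model_name: str, method: str, amount: float) -> str:
    # Stand-in for the expensive build + prune/quantize step; in the app
    # this would construct and optimize the actual model.
    global build_count
    build_count += 1
    return f"{model_name}:{method}:{amount}"

get_optimized("resnet50", "prune", 0.4)
get_optimized("resnet50", "prune", 0.4)  # second call is served from cache
```

Keying the cache on the full option tuple keeps results correct when users switch models or sparsity levels mid-session.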
## CLI Mode

- Run without the UI: `python app.py --cli --image path/to/img.jpg --mode prune --model resnet50 --device auto`
- Modes: `--mode prune` (structured pruning at 0.4 sparsity) or `--mode quant` (dynamic quantization). Both emit the metrics table and the list of exported artifacts.
- Devices: `auto` chooses CUDA → MPS → CPU based on availability; `cpu`/`cuda`/`mps` force a device. Dynamic/weight-only quantization forces CPU for kernel support even if a GPU is requested.
- Models: any entry from `MODEL_OPTIONS` in `app.py`.

## Troubleshooting

- **Slow downloads**: The first run downloads pretrained weights (~100 MB). Subsequent runs use cached files.
- **CUDA errors**: Ensure the correct CUDA-enabled PyTorch build is installed if you intend to run on GPU.
- **Quantized model larger than expected**: The state dictionary includes dequantized tensors for some paths (e.g., dynamic quantization). Consider TorchScript or ONNX export for compact deployment artifacts.

## License

This project inherits the default license of the repository. Replace or update this section if you add a specific license.