shriarul5273 committed
Commit
1e97e9e
·
1 Parent(s): 34a90ad

Added segmentation models for pruning and quantization

README.md CHANGED
@@ -11,19 +11,29 @@ pinned: false
11
 
12
  # Model Optimization Lab
13
 
14
- Interactive Gradio playground for comparing pruning and quantization on ImageNet-classification backbones. Upload any image and observe how latency, confidence, and model size change when applying different compression recipes. Pretrained weights are loaded by default; set `MODEL_OPT_PRETRAINED=0` if you want random initialization for experimentation.
15
 
16
  ## Features
17
- - Baseline FP32 inference using cached backbones (ResNet-50, MobileNetV3, EfficientNet-B0, etc.).
18
- - Pruning tab: structured/unstructured pruning with configurable sparsity and size/latency comparison.
19
- - Quantization tab: dynamic, weight-only INT8, and FP16 passes with CPU-safe fallbacks for unsupported kernels.
20
- - Automated metric tables and Top-5 bar charts to visualize confidence shifts between optimized variants.
21
  - Lightweight CLI mode for quick experiments without launching the UI.
22
 
23
  ## Requirements
24
  - Python 3.9+
25
  - PyTorch with CPU support (GPU optional but recommended for FP16 experiments).
26
- - The packages listed in `requirements.txt` or installed via `pip install -r requirements.txt` (create the file if missing with entries like `torch`, `timm`, `gradio`, `pandas`, `torchvision`).
27
 
28
  ## Quick Start
29
  1. Clone the repository:
@@ -43,32 +53,42 @@ Interactive Gradio playground for comparing pruning and quantization on ImageNet
43
  5. Open the local Gradio URL (printed in the terminal) in your browser.
44
 
45
  ## Using the App
46
- 1. **Upload an image** or pick one of the provided examples.
47
- 2. Choose the **Base Model** dropdown (ResNet-50, MobileNetV3-Large, EfficientNet-B0, ConvNeXt-Tiny, ViT-B/16, RegNetY-016, EfficientNet-Lite0).
48
  3. Pick a **Hardware Preset** or keep `custom`:
49
   - Edge CPU — CPU, channels-last off, dynamic quantization, 30% pruning.
50
   - Datacenter GPU — CUDA, channels-last on, `torch.compile`, FP16 quantization, 20% pruning.
51
   - Apple MPS — MPS, FP16 quantization, 20% pruning.
52
- 4. Pick a tab and set options, then click **Run**.
53
 
54
- ### Pruning tab options
55
  - `Pruning Method`: `structured` (LN-structured) or `unstructured` (L1). Applied to Conv2d weights before export.
56
  - `Pruning Amount`: 0.1–0.9 sparsity. Higher numbers zero more weights; latency impact depends on kernel support.
57
  - `Device`: `auto` picks CUDA → MPS → CPU. Channels-last is only honored on CUDA.
58
  - `Channels-last input (CUDA)`: Converts tensors to channels-last for better CUDA kernel throughput.
59
  - `Mixed precision (AMP)`: Enables CUDA autocast for FP16/FP32 mixes.
60
  - `Torch compile (PyTorch 2)`: Wraps the pruned model in `torch.compile` when available.
61
- - Exports: TorchScript (`pruned_model.ts`), ONNX (`pruned_model.onnx`), JSON report, always saves `pruned_state_dict.pth`.
62
- - Outputs: comparison metrics, Top-5 bar chart, per-layer sparsity table, download list of artifacts.
63
 
64
- ### Quantization tab options
65
  - `Quantization Type`: `dynamic`/`weight_only` (INT8 linear layers on CPU), or `fp16` (casts model to half precision).
66
  - `Device`: `auto` picks CUDA → MPS → CPU; dynamic/weight-only runs force CPU execution for kernel support.
67
  - `Channels-last input (CUDA)`: Same as pruning; ignored on CPU.
68
  - `Mixed precision (AMP)`: Applies CUDA autocast to the quantized forward pass.
69
  - `Torch compile (PyTorch 2)`: Compiles the quantized model when available.
70
- - Exports: TorchScript (`quantized_model.ts`), ONNX (`quantized_model.onnx`), JSON report, always saves `quantized_state_dict.pth`.
71
- - Outputs: comparison metrics, Top-5 bar chart, download list of artifacts.
72
 
73
  ### What gets exported
74
  - Artifacts are written to `exports/`. JSON reports include the chosen options, metrics, and Top-5 results for both the baseline and optimized variants.
@@ -76,7 +96,10 @@ Interactive Gradio playground for comparing pruning and quantization on ImageNet
76
  - State dicts are always saved for reproducibility; disable or prune them manually if you are embedding this module elsewhere.
77
 
78
  ### Output Interpretation Tips
79
- - **Top-1 Prediction**: Labels come from ImageNet synsets, so some entries include multiple comma-separated synonyms (e.g., `chambered nautilus, pearly nautilus`).
80
  - **Latency (ms)**: The measured inference latency for each pass. Large numbers for quantized runs may indicate preprocessing overhead rather than faster model execution; see [Performance Notes](#performance-notes).
81
  - **Model Size (MB)**: Serialized state dictionary size after saving to disk.
82
 
@@ -87,10 +110,11 @@ Interactive Gradio playground for comparing pruning and quantization on ImageNet
87
  - FP16 inference is beneficial on GPUs. On CPU, PyTorch often casts half tensors back to float32, introducing overhead.
88
 
89
  ## Extending the Lab
90
- - Swap in different architectures by changing the `timm.create_model` call in `app.py`.
 
91
  - Add calibration data and static INT8 quantization to include convolution layers.
92
  - Cache optimized models to avoid recomputation across requests.
93
- - Integrate evaluation datasets to quantify accuracy drop beyond top-1 confidence.
94
 
95
  ## CLI Mode
96
  - Run without the UI: `python app.py --cli --image path/to/img.jpg --mode prune --model resnet50 --device auto`
 
11
 
12
  # Model Optimization Lab
13
 
14
+ Interactive Gradio playground for comparing pruning and quantization on both ImageNet-classification and ADE20K-segmentation models. Upload any image and observe how latency, confidence, model size, and segmentation quality change when applying different compression recipes. Pretrained weights are loaded by default; set `MODEL_OPT_PRETRAINED=0` if you want random initialization for experimentation.
15
 
16
  ## Features
17
+ - **Classification Tasks**: Baseline FP32 inference using cached backbones (ResNet-50, MobileNetV3, EfficientNet-B0, ConvNeXt-Tiny, ViT-B/16, RegNetY-016, EfficientNet-Lite0).
18
+ - **Segmentation Tasks**: Pretrained ADE20K models (SegFormer B0/B4, DPT Large, UPerNet ConvNeXt-Tiny) with 150-class semantic segmentation.
19
+ - **Pruning tabs**: Structured/unstructured pruning with configurable sparsity and comprehensive size/latency comparison for both classification and segmentation.
20
+ - **Quantization tabs**: Dynamic, weight-only INT8, and FP16 passes with CPU-safe fallbacks for unsupported kernels, available for both task types.
21
+ - **Visual Comparisons**:
22
+ - Classification: Automated metric tables and Top-5 bar charts to visualize confidence shifts.
23
+ - Segmentation: Image sliders for overlay/mask comparisons, class distribution tables, and mask agreement metrics.
24
+ - **Export Options**: TorchScript, ONNX, JSON reports, and state dictionaries for all optimization variants.
25
  - Lightweight CLI mode for quick experiments without launching the UI.
26
 
27
  ## Requirements
28
  - Python 3.9+
29
  - PyTorch with CPU support (GPU optional but recommended for FP16 experiments).
30
+ - The packages listed in `requirements.txt`:
31
+ - `torch`, `torchvision` - Core PyTorch framework
32
+ - `timm` - Classification model architectures
33
+ - `segmentation-models-pytorch` - Segmentation model architectures
34
+ - `albumentations` - Image preprocessing for segmentation models
35
+ - `gradio` - Web UI framework
36
+ - `pandas`, `matplotlib`, `numpy`, `pillow` - Data processing and visualization
37
 
38
  ## Quick Start
39
  1. Clone the repository:
 
53
  5. Open the local Gradio URL (printed in the terminal) in your browser.
54
 
55
  ## Using the App
56
+ 1. **Upload an image** or pick one of the provided examples (ImageNet samples for classification, ADE20K validation images for segmentation).
57
+ 2. Choose the **Base Model** dropdown:
58
+ - **Classification**: ResNet-50, MobileNetV3-Large, EfficientNet-B0, ConvNeXt-Tiny, ViT-B/16, RegNetY-016, EfficientNet-Lite0
59
+ - **Segmentation**: SegFormer B0/B4 (ADE20K 512x512), DPT Large (ADE20K), UPerNet ConvNeXt-Tiny (ADE20K)
60
  3. Pick a **Hardware Preset** or keep `custom` (a preset sketch follows this list):
61
   - Edge CPU — CPU, channels-last off, dynamic quantization, 30% pruning.
62
   - Datacenter GPU — CUDA, channels-last on, `torch.compile`, FP16 quantization, 20% pruning.
63
   - Apple MPS — MPS, FP16 quantization, 20% pruning.
64
+ 4. Select a tab (Pruning-Classification, Quantization-Classification, Pruning-Segmentation, or Quantization-Segmentation), configure options, then click **Run**.
65
 
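The presets map onto a small configuration table in `app.py`. The sketch below is a hypothetical approximation based on the keys the run handlers read (`device`, `channels_last`, `compile`, `amp`, `quant`, `prune_amount`); the exact values live in the `PRESETS` dict in the repository.

```python
# Hypothetical preset table; keys mirror what the run handlers read from PRESETS,
# values only approximate the descriptions above and may differ from app.py.
PRESETS = {
    "Edge CPU": {
        "device": "cpu", "channels_last": False, "compile": False,
        "amp": False, "quant": "dynamic", "prune_amount": 0.3,
    },
    "Datacenter GPU": {
        "device": "cuda", "channels_last": True, "compile": True,
        "amp": True, "quant": "fp16", "prune_amount": 0.2,
    },
    "Apple MPS": {
        "device": "mps", "channels_last": False, "compile": False,
        "amp": False, "quant": "fp16", "prune_amount": 0.2,
    },
}
```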
66
+ ### Pruning tab options (Classification & Segmentation)
67
  - `Pruning Method`: `structured` (LN-structured) or `unstructured` (L1). Applied to Conv2d weights before export (a minimal sketch follows this list).
68
  - `Pruning Amount`: 0.1–0.9 sparsity. Higher numbers zero more weights; latency impact depends on kernel support.
69
  - `Device`: `auto` picks CUDA → MPS → CPU. Channels-last is only honored on CUDA.
70
  - `Channels-last input (CUDA)`: Converts tensors to channels-last for better CUDA kernel throughput.
71
  - `Mixed precision (AMP)`: Enables CUDA autocast for FP16/FP32 mixes.
72
  - `Torch compile (PyTorch 2)`: Wraps the pruned model in `torch.compile` when available.
73
+ - **Exports**:
74
+ - Classification: `pruned_model.ts`, `pruned_model.onnx`, `pruned_report.json`, `pruned_state_dict.pth`
75
+ - Segmentation: `pruned_seg_model.ts`, `pruned_seg_model.onnx`, `pruned_seg_report.json`, `pruned_seg_state_dict.pth`
76
+ - **Outputs**:
77
+ - Classification: Comparison metrics, Top-5 bar chart, per-layer sparsity table, download list
78
+ - Segmentation: Comparison metrics, class distribution table, overlay/mask sliders, per-layer sparsity table, download list
79
 
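A minimal sketch of the two pruning recipes described above, using `torch.nn.utils.prune` on Conv2d weights only; the app's own `apply_pruning` helper may differ in details such as which layers it touches or when masks are removed.

```python
import torch.nn as nn
import torch.nn.utils.prune as prune


def prune_conv_weights(model: nn.Module, amount: float = 0.3, method: str = "unstructured") -> nn.Module:
    """Zero a fraction of Conv2d weights before export (illustrative only)."""
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            if method == "structured":
                # LN-structured pruning drops whole output channels (dim=0) by L2 norm.
                prune.ln_structured(module, name="weight", amount=amount, n=2, dim=0)
            else:
                # Unstructured L1 pruning zeroes the smallest-magnitude individual weights.
                prune.l1_unstructured(module, name="weight", amount=amount)
            # Bake the mask into the weight tensor so exports and size metrics see the sparsity.
            prune.remove(module, "weight")
    return model
```

Note that pruning does not shrink tensor shapes, so latency gains still depend on kernels that can exploit the zeroed weights.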
80
+ ### Quantization tab options (Classification & Segmentation)
81
  - `Quantization Type`: `dynamic`/`weight_only` (INT8 linear layers on CPU), or `fp16` (casts model to half precision); a minimal sketch follows this list.
82
  - `Device`: `auto` picks CUDA → MPS → CPU; dynamic/weight-only runs force CPU execution for kernel support.
83
  - `Channels-last input (CUDA)`: Same as pruning; ignored on CPU.
84
  - `Mixed precision (AMP)`: Applies CUDA autocast to the quantized forward pass.
85
  - `Torch compile (PyTorch 2)`: Compiles the quantized model when available.
86
+ - **Exports**:
87
+ - Classification: `quantized_model.ts`, `quantized_model.onnx`, `quant_report.json`, `quantized_state_dict.pth`
88
+ - Segmentation: `quant_seg_model.ts`, `quant_seg_model.onnx`, `quant_seg_report.json`, `quant_seg_state_dict.pth`
89
+ - **Outputs**:
90
+ - Classification: Comparison metrics, Top-5 bar chart, download list
91
+ - Segmentation: Comparison metrics, class distribution table, overlay/mask sliders, download list
92
 
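A minimal sketch of the three recipes named above, assuming the standard PyTorch quantization APIs; the app's `apply_quantization` helper may implement `weight_only` differently.

```python
import torch
import torch.nn as nn


def quantize_model(model: nn.Module, q_type: str) -> nn.Module:
    """Illustrative mapping of the dropdown values to PyTorch quantization calls."""
    if q_type in ("dynamic", "weight_only"):
        # Post-training INT8 on nn.Linear modules; PyTorch ships CPU kernels only,
        # which is why these modes force CPU execution. weight_only is shown with the
        # same call as a stand-in; the app may keep FP32 activations a different way.
        return torch.ao.quantization.quantize_dynamic(model.cpu(), {nn.Linear}, dtype=torch.qint8)
    if q_type == "fp16":
        # Cast parameters and buffers to half precision; worthwhile on CUDA/MPS, not on CPU.
        return model.half()
    return model
```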
93
  ### What gets exported
94
  - Artifacts are written to `exports/`. JSON reports include the chosen options, metrics, and Top-5 results for both the baseline and optimized variants (a sample report shape is sketched below).
 
96
  - State dicts are always saved for reproducibility; disable or prune them manually if you are embedding this module elsewhere.
97
 
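For reference, the segmentation pruning report written to `exports/` follows roughly this shape. Field names are taken from the export code in `app.py`; classification reports carry Top-5 results instead of a class distribution, and the values shown here are illustrative.

```python
# Approximate shape of exports/pruned_seg_report.json.
report = {
    "model": "SegFormer B0 (ADE20K 512x512)",
    "checkpoint": "smp-hub/segformer-b0-512x512-ade-160k",
    "dataset": "ADE20K",
    "pruning": {"method": "structured", "amount": 0.4},
    "metrics": {},             # filled from metrics_df.to_dict()
    "class_distribution": {},  # filled from class_df.to_dict()
}
```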
98
  ### Output Interpretation Tips
99
+ - **Top-1 Prediction (Classification)**: Labels come from ImageNet synsets, so some entries include multiple comma-separated synonyms (e.g., `chambered nautilus, pearly nautilus`).
100
+ - **Mask Agreement (Segmentation)**: Percentage of pixels where the original and optimized models predict the same class. 100% means identical masks; lower values indicate divergence (see the sketch after this list).
101
+ - **Class Distribution (Segmentation)**: Shows the top 25 most prevalent classes by pixel coverage, with percentages and counts for both models.
102
+ - **Image Sliders (Segmentation)**: Drag the slider to compare original vs. optimized overlays or raw masks side-by-side.
103
  - **Latency (ms)**: The measured inference latency for each pass. Large numbers for quantized runs may indicate preprocessing overhead rather than faster model execution; see [Performance Notes](#performance-notes).
104
  - **Model Size (MB)**: Serialized state dictionary size after saving to disk.
105
 
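For the segmentation metrics, mask agreement is simply the fraction of pixels assigned the same class by both models, mirroring the computation in `app.py`:

```python
import numpy as np


def mask_agreement(mask_original: np.ndarray, mask_optimized: np.ndarray) -> float:
    """Percentage of pixels where both HxW class maps agree; 100.0 means identical masks."""
    return float((mask_original == mask_optimized).mean() * 100.0)
```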
 
110
  - FP16 inference is beneficial on GPUs. On CPU, PyTorch often casts half tensors back to float32, introducing overhead.
111
 
112
  ## Extending the Lab
113
+ - **Classification**: Swap in different architectures by changing the `timm.create_model` call in `app.py` (see the sketch after this list).
114
+ - **Segmentation**: Add new models from the [smp-hub](https://huggingface.co/smp-hub) collection by adding entries to `SEGMENTATION_MODEL_CONFIGS`.
115
  - Add calibration data and static INT8 quantization to include convolution layers.
116
  - Cache optimized models to avoid recomputation across requests.
117
+ - Integrate evaluation datasets to quantify accuracy drop (classification: top-1/top-5, segmentation: mIoU, pixel accuracy).
118
 
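Both extension points boil down to a one-line change in `app.py`. The sketch below is hedged: `resnet34` and the SegFormer B2 checkpoint are only example names, and the dataclass mirrors the `SegmentationModelConfig` already defined in `app.py`.

```python
from dataclasses import dataclass

import timm

# Classification: any timm architecture can replace the defaults used in app.py.
model = timm.create_model("resnet34", pretrained=True)


# Segmentation: this dataclass mirrors the one in app.py; appending a new instance
# to SEGMENTATION_MODEL_CONFIGS there adds the model to the dropdown.
@dataclass(frozen=True)
class SegmentationModelConfig:
    name: str
    checkpoint: str
    classes: int = 150
    dataset: str = "ADE20K"


# Example entry only; check the smp-hub collection for real checkpoint IDs.
new_entry = SegmentationModelConfig(
    "SegFormer B2 (ADE20K 512x512)", "smp-hub/segformer-b2-512x512-ade-160k"
)
```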
119
  ## CLI Mode
120
  - Run without the UI: `python app.py --cli --image path/to/img.jpg --mode prune --model resnet50 --device auto`
app.py CHANGED
@@ -4,6 +4,7 @@ import json
4
  import os
5
  import time
6
  from pathlib import Path
 
7
 
8
  import matplotlib.pyplot as plt
9
  import gradio as gr
@@ -13,8 +14,13 @@ import timm
13
  import torch
14
  import torch.nn as nn
15
  import torch.nn.utils.prune as prune
16
- from PIL import Image
 
17
  from torchvision import transforms
18
 
19
 
20
  # ---------------------------------------------
@@ -60,6 +66,103 @@ PRESETS = {
60
  _MODEL_CACHE: dict[str, torch.nn.Module] = {}
61
  _TRANSFORM_CACHE: dict[str, transforms.Compose] = {}
62
 
63
 
64
  def select_device(device_str: str) -> torch.device:
65
  """Return a valid torch.device based on user selection."""
@@ -116,6 +219,267 @@ def clone_model(model_name: str):
116
  return fresh
117
 
118
 
119
  # ---------------------------------------------
120
  # Image Preprocess
121
  # ---------------------------------------------
@@ -514,10 +878,314 @@ def run_quantized(
514
 
515
  print("=== RUN QUANTIZED COMPLETE ===")
516
  return metrics_df, chart_fig, downloads
517
  # ---------------------------------------------
518
  # GRADIO UI
519
  # ---------------------------------------------
520
  examples = [["examples/cat.jpg"], ["examples/dog.jpg"], ["examples/bird.jpg"], ["examples/car.jpg"], ["examples/elephant.jpg"]]
 
521
 
522
 
523
  def create_demo():
@@ -530,10 +1198,11 @@ def create_demo():
530
  if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
531
  device_opts.append("mps")
532
  preset_opts = list(PRESETS.keys()) + ["custom"]
 
533
 
534
  with gr.Tabs():
535
  # ---- PRUNING TAB ----
536
- with gr.Tab("Pruning"):
537
  with gr.Row():
538
  with gr.Column():
539
  img_p = gr.Image(label="Upload Image")
@@ -551,13 +1220,34 @@ def create_demo():
551
  btn_p = gr.Button("Run Pruned Model")
552
  gr.Examples(examples=examples, inputs=img_p)
553
  gr.Markdown(
554
- "**Option Guide**\n"
555
- "- Base Model: select the timm architecture to optimize (pretrained when available).\n"
556
- "- Hardware Preset: load device, precision, and pruning defaults for common targets; choose custom to tweak manually.\n"
557
- "- Pruning Method/Amount: set structured vs unstructured pruning and the fraction of weights removed.\n"
558
- "- Device & CUDA Toggles: force CPU/CUDA/MPS and optionally enable channels-last or AMP for CUDA speedups.\n"
559
- "- Torch compile: wrap the model with torch.compile (PyTorch 2) to experiment with graph optimizations.\n"
560
- "- Export options: drop TorchScript, ONNX, and JSON reports into the `exports/` directory."
561
  )
562
 
563
  with gr.Column():
@@ -587,7 +1277,7 @@ def create_demo():
587
  )
588
 
589
  # ---- QUANTIZATION TAB ----
590
- with gr.Tab("Quantization"):
591
  with gr.Row():
592
  with gr.Column():
593
  img_q = gr.Image(label="Upload Image")
@@ -604,12 +1294,30 @@ def create_demo():
604
  btn_q = gr.Button("Run Quantized Model")
605
  gr.Examples(examples=examples, inputs=img_q)
606
  gr.Markdown(
607
- "**Option Guide**\n"
608
- "- Base Model & Preset: pick the architecture and optional hardware profile to prefill device and quant settings.\n"
609
- "- Quantization Type: `dynamic` applies post-training int8 to linear layers (forces CPU kernels), `weight_only` stores int8 weights with fp32 activations for a lighter CPU model, while `fp16` casts the full network to half precision for GPUs with native fp16 support.\n"
610
- "- Device & CUDA Toggles: run on CPU/CUDA/MPS; channels-last and AMP only benefit CUDA workloads.\n"
611
- "- Torch compile: try PyTorch 2 compile for extra speed when supported.\n"
612
- "- Export options: generate TorchScript, ONNX, and JSON artifacts inside `exports/`."
613
  )
614
 
615
 
@@ -637,6 +1345,192 @@ def create_demo():
637
  outputs=[metrics_q, chart_q, downloads_q],
638
  )
639
 
640
  return demo
641
 
642
 
 
4
  import os
5
  import time
6
  from pathlib import Path
7
+ from dataclasses import dataclass
8
 
9
  import matplotlib.pyplot as plt
10
  import gradio as gr
 
14
  import torch
15
  import torch.nn as nn
16
  import torch.nn.utils.prune as prune
17
+ import segmentation_models_pytorch as smp
18
+ from PIL import Image, ImageDraw, ImageFont
19
  from torchvision import transforms
20
+ try:
21
+ import albumentations as A
22
+ except ModuleNotFoundError: # pragma: no cover - optional dependency
23
+ A = None
24
 
25
 
26
  # ---------------------------------------------
 
66
  _MODEL_CACHE: dict[str, torch.nn.Module] = {}
67
  _TRANSFORM_CACHE: dict[str, transforms.Compose] = {}
68
 
69
+ @dataclass(frozen=True)
70
+ class SegmentationModelConfig:
71
+ name: str
72
+ checkpoint: str
73
+ classes: int = 150
74
+ dataset: str = "ADE20K"
75
+
76
+
77
+ SEGMENTATION_MODEL_CONFIGS: tuple[SegmentationModelConfig, ...] = (
78
+ SegmentationModelConfig("SegFormer B0 (ADE20K 512x512)", "smp-hub/segformer-b0-512x512-ade-160k"),
79
+ SegmentationModelConfig("SegFormer B4 (ADE20K 512x512)", "smp-hub/segformer-b4-512x512-ade-160k"),
80
+ SegmentationModelConfig("DPT Large (ADE20K)", "smp-hub/dpt-large-ade20k"),
81
+ SegmentationModelConfig("UPerNet ConvNeXt-Tiny (ADE20K)", "smp-hub/upernet-convnext-tiny"),
82
+ )
83
+ SEGMENTATION_MODEL_MAP = {cfg.name: cfg for cfg in SEGMENTATION_MODEL_CONFIGS}
84
+
85
+ _SEG_BASE_PALETTE = np.array(
86
+ [
87
+ [0, 0, 0],
88
+ [0, 114, 189],
89
+ [217, 83, 25],
90
+ [237, 177, 32],
91
+ [126, 47, 142],
92
+ [119, 172, 48],
93
+ [77, 190, 238],
94
+ [162, 20, 47],
95
+ [163, 200, 236],
96
+ [255, 127, 14],
97
+ [255, 188, 121],
98
+ [111, 118, 207],
99
+ [204, 121, 167],
100
+ [148, 103, 189],
101
+ [44, 160, 44],
102
+ [23, 190, 207],
103
+ [31, 119, 180],
104
+ [255, 152, 150],
105
+ [214, 39, 40],
106
+ [188, 189, 34],
107
+ ],
108
+ dtype=np.uint8,
109
+ )
110
+
111
+ _SEG_MODEL_CACHE: dict[str, torch.nn.Module] = {}
112
+ _SEG_TRANSFORM_CACHE: dict[str, object] = {}
113
+ _SEG_PALETTE_CACHE: dict[int, np.ndarray] = {}
114
+
115
+ ADE20K_CLASS_NAMES = [
116
+ "wall", "building", "sky", "floor", "tree", "ceiling", "road", "bed", "windowpane", "grass",
117
+ "cabinet", "sidewalk", "person", "earth", "door", "table", "mountain", "plant", "curtain", "chair",
118
+ "car", "water", "painting", "sofa", "shelf", "house", "sea", "mirror", "rug", "field",
119
+ "armchair", "seat", "fence", "desk", "rock", "wardrobe", "lamp", "bathtub", "railing", "cushion",
120
+ "base", "box", "column", "signboard", "chest of drawers", "counter", "sand", "sink", "skyscraper", "fireplace",
121
+ "refrigerator", "grandstand", "path", "stairs", "runway", "case", "pool table", "pillow", "screen door", "stairway",
122
+ "river", "bridge", "bookcase", "blind", "coffee table", "toilet", "flower", "book", "hill", "bench",
123
+ "countertop", "stove", "palm", "kitchen island", "computer", "swivel chair", "boat", "bar", "arcade machine", "hovel",
124
+ "bus", "towel", "light", "truck", "tower", "chandelier", "awning", "streetlight", "booth", "television receiver",
125
+ "airplane", "dirt track", "apparel", "pole", "land", "bannister", "escalator", "ottoman", "bottle", "buffet",
126
+ "poster", "stage", "van", "ship", "fountain", "conveyer belt", "canopy", "washer", "plaything", "swimming pool",
127
+ "stool", "barrel", "basket", "waterfall", "tent", "bag", "minibike", "cradle", "oven", "ball",
128
+ "food", "step", "tank", "trade name", "microwave", "pot", "animal", "bicycle", "lake", "dishwasher",
129
+ "screen", "blanket", "sculpture", "hood", "sconce", "vase", "traffic light", "tray", "ashcan", "fan",
130
+ "pier", "crt screen", "plate", "monitor", "bulletin board", "shower", "radiator", "glass", "clock", "flag"
131
+ ]
132
+
133
+
134
+ def add_image_label(img: Image.Image, label: str) -> Image.Image:
135
+ """Add a text label at the top of an image."""
136
+ img_array = np.array(img)
137
+ h, w = img_array.shape[:2]
138
+
139
+ # Create canvas with extra space at top for label
140
+ canvas = np.ones((h + 40, w, 3), dtype=np.uint8) * 255
141
+ canvas[40:, :] = img_array
142
+
143
+ # Convert back to PIL for text drawing
144
+ canvas_img = Image.fromarray(canvas)
145
+ draw = ImageDraw.Draw(canvas_img)
146
+
147
+ # Try to use a nice font, fall back to default if not available
148
+ try:
149
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 20)
150
+ except:
151
+ try:
152
+ font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 20)
153
+ except:
154
+ font = ImageFont.load_default()
155
+
156
+ # Get text size and center it
157
+ bbox = draw.textbbox((0, 0), label, font=font)
158
+ text_width = bbox[2] - bbox[0]
159
+ text_x = (w - text_width) // 2
160
+
161
+ # Draw text
162
+ draw.text((text_x, 10), label, fill=(0, 0, 0), font=font)
163
+
164
+ return canvas_img
165
+
166
 
167
  def select_device(device_str: str) -> torch.device:
168
  """Return a valid torch.device based on user selection."""
 
219
  return fresh
220
 
221
 
222
+ # ---------------------------------------------
223
+ # Segmentation Utilities
224
+ # ---------------------------------------------
225
+ def _require_albumentations():
226
+ if A is None:
227
+ raise RuntimeError(
228
+ "Albumentations is required for pretrained segmentation models. "
229
+ "Install it with `pip install albumentations` or add it to your environment."
230
+ )
231
+
232
+
233
+ def get_segmentation_model(config: SegmentationModelConfig) -> nn.Module:
234
+ key = config.checkpoint
235
+ if key not in _SEG_MODEL_CACHE:
236
+ model = smp.from_pretrained(config.checkpoint).eval()
237
+ _SEG_MODEL_CACHE[key] = model
238
+ return _SEG_MODEL_CACHE[key]
239
+
240
+
241
+ def clone_segmentation_model(config: SegmentationModelConfig) -> nn.Module:
242
+ base = get_segmentation_model(config)
243
+ fresh = smp.from_pretrained(config.checkpoint).eval()
244
+ fresh.load_state_dict(base.state_dict())
245
+ return fresh
246
+
247
+
248
+ def get_segmentation_transform(config: SegmentationModelConfig):
249
+ key = config.checkpoint
250
+ if key in _SEG_TRANSFORM_CACHE:
251
+ return _SEG_TRANSFORM_CACHE[key]
252
+
253
+ _require_albumentations()
254
+ try:
255
+ preprocessing = A.Compose.from_pretrained(config.checkpoint)
256
+ except Exception as exc: # pragma: no cover - depends on network availability
257
+ raise RuntimeError(f"Failed to load preprocessing pipeline for {config.checkpoint}: {exc}") from exc
258
+
259
+ def _transform(image):
260
+ if image is None:
261
+ raise ValueError("No image provided")
262
+ if not isinstance(image, Image.Image):
263
+ if isinstance(image, np.ndarray):
264
+ array = image
265
+ if array.dtype != np.uint8:
266
+ array = (np.clip(array, 0, 1) * 255).astype(np.uint8)
267
+ image_rgb = Image.fromarray(array)
268
+ else:
269
+ raise ValueError(f"Unsupported image type: {type(image)}")
270
+ else:
271
+ image_rgb = image
272
+
273
+ image_rgb = image_rgb.convert("RGB")
274
+ np_image = np.array(image_rgb)
275
+ processed = preprocessing(image=np_image)["image"]
276
+ if isinstance(processed, torch.Tensor):
277
+ processed_np = processed.detach().cpu().numpy()
278
+ else:
279
+ processed_np = np.asarray(processed, dtype=np.float32)
280
+ tensor = torch.from_numpy(processed_np.transpose(2, 0, 1)).float()
281
+ return tensor, image_rgb
282
+
283
+ _SEG_TRANSFORM_CACHE[key] = _transform
284
+ return _transform
285
+
286
+
287
+ def get_segmentation_palette(class_count: int) -> np.ndarray:
288
+ if class_count in _SEG_PALETTE_CACHE:
289
+ return _SEG_PALETTE_CACHE[class_count]
290
+
291
+ base_len = len(_SEG_BASE_PALETTE)
292
+ if class_count <= base_len:
293
+ palette = _SEG_BASE_PALETTE[:class_count]
294
+ else:
295
+ palette = np.zeros((class_count, 3), dtype=np.uint8)
296
+ palette[:base_len] = _SEG_BASE_PALETTE
297
+ rng = np.random.default_rng(1337)
298
+ palette[base_len:] = rng.integers(0, 256, size=(class_count - base_len, 3), endpoint=False, dtype=np.uint8)
299
+ palette[:, 0] |= 1 # ensure colors are not pure black except index 0
300
+ palette[0] = np.array([0, 0, 0], dtype=np.uint8)
301
+
302
+ _SEG_PALETTE_CACHE[class_count] = palette
303
+ return palette
304
+
305
+
306
+ def colorize_mask(mask: np.ndarray, class_count: int) -> Image.Image:
307
+ if mask.ndim != 2:
308
+ raise ValueError("Mask must be 2D for colorization")
309
+ palette = get_segmentation_palette(class_count)
310
+ indexed = np.mod(mask, class_count)
311
+ colored = palette[indexed]
312
+ return Image.fromarray(colored.astype(np.uint8))
313
+
314
+
315
+ def overlay_mask(image: Image.Image, mask_image: Image.Image, alpha: float = 0.5) -> Image.Image:
316
+ base = np.array(image.convert("RGB"), dtype=np.float32)
317
+ mask_resized = mask_image.resize(image.size, Image.NEAREST)
318
+ mask_arr = np.array(mask_resized, dtype=np.float32)
319
+ blended = (1.0 - alpha) * base + alpha * mask_arr
320
+ return Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))
321
+
322
+
323
+ def summarize_mask(mask: np.ndarray, class_count: int) -> list[dict[str, float]]:
324
+ flat = mask.reshape(-1)
325
+ counts = np.bincount(flat, minlength=class_count)
326
+ total = float(flat.size)
327
+ summary = []
328
+ for idx in range(class_count):
329
+ count = int(counts[idx])
330
+ percent = (count / total * 100.0) if total else 0.0
331
+ summary.append({"index": idx, "count": count, "percent": percent})
332
+ return summary
333
+
334
+
335
+ def get_class_labels(config: SegmentationModelConfig) -> list[str]:
336
+ # Try to get labels from model metadata first
337
+ model = get_segmentation_model(config)
338
+ meta = getattr(model, "meta", {}) or {}
339
+ dataset_meta = meta.get("dataset", {}) or {}
340
+ labels = dataset_meta.get("class_names") or dataset_meta.get("classes_names")
341
+
342
+ # If not in metadata, use dataset-specific labels
343
+ if not labels:
344
+ if config.dataset == "ADE20K" and config.classes == 150:
345
+ labels = ADE20K_CLASS_NAMES
346
+ else:
347
+ labels = [f"Class {idx}" for idx in range(config.classes)]
348
+ else:
349
+ labels = list(labels)
350
+
351
+ # Ensure we have the right number of labels
352
+ if len(labels) < config.classes:
353
+ labels.extend(f"Class {len(labels) + i}" for i in range(config.classes - len(labels)))
354
+ return labels[: config.classes]
355
+
356
+
357
+ def run_segmentation_inference(
358
+ model: nn.Module,
359
+ image,
360
+ device: torch.device,
361
+ transform_fn,
362
+ channels_last: bool,
363
+ warmup: bool,
364
+ use_amp: bool,
365
+ class_count: int,
366
+ ):
367
+ tensor, original_image = transform_fn(image)
368
+
369
+ model = model.to(device)
370
+ input_tensor = tensor.unsqueeze(0).to(device)
371
+
372
+ if channels_last and device.type == "cuda":
373
+ input_tensor = input_tensor.to(memory_format=torch.channels_last)
374
+
375
+ if next(model.parameters()).dtype == torch.float16:
376
+ input_tensor = input_tensor.half()
377
+
378
+ if warmup:
379
+ with torch.no_grad():
380
+ model(input_tensor)
381
+
382
+ amp_ctx = torch.cuda.amp.autocast(enabled=use_amp and device.type == "cuda")
383
+ start = time.time()
384
+ with torch.no_grad(), amp_ctx:
385
+ logits = model(input_tensor)
386
+ latency = (time.time() - start) * 1000
387
+
388
+ if isinstance(logits, (list, tuple)):
389
+ logits = logits[0]
390
+
391
+ logits = logits.detach().cpu()
392
+ probs = torch.softmax(logits, dim=1)
393
+ mask_tensor = torch.argmax(probs, dim=1)[0]
394
+ mask_processed = mask_tensor.cpu().numpy().astype(np.int64)
395
+
396
+ mean_conf = float(probs.max(dim=1)[0].mean().item())
397
+
398
+ mask_processed_image = colorize_mask(mask_processed, class_count)
399
+ mask_original_l = Image.fromarray(mask_processed.astype(np.uint8), mode="L").resize(original_image.size, Image.NEAREST)
400
+ mask_original_np = np.array(mask_original_l, dtype=np.int64)
401
+ mask_original_image = colorize_mask(mask_original_np, class_count)
402
+ overlay_original = overlay_mask(original_image, mask_original_image)
403
+ class_summary = summarize_mask(mask_original_np, class_count)
404
+
405
+ return {
406
+ "latency": latency,
407
+ "mask_processed": mask_processed,
408
+ "mask_original": mask_original_np,
409
+ "mask_image_processed": mask_processed_image,
410
+ "mask_image_original": mask_original_image,
411
+ "overlay_original": overlay_original,
412
+ "mean_confidence": mean_conf,
413
+ "class_summary": class_summary,
414
+ }
415
+
416
+
417
+ def build_segmentation_metrics(
418
+ original_result: dict,
419
+ optimized_result: dict,
420
+ size_original: float,
421
+ size_optimized: float,
422
+ optimized_label: str,
423
+ ) -> pd.DataFrame:
424
+ mask_original = original_result["mask_original"]
425
+ mask_optimized = optimized_result["mask_original"]
426
+ agreement = float((mask_original == mask_optimized).mean() * 100.0)
427
+
428
+ metrics_df = pd.DataFrame(
429
+ {
430
+ "Metric": [
431
+ "Latency (ms)",
432
+ "Mean Confidence",
433
+ "Model Size (MB)",
434
+ "Mask Agreement (%)",
435
+ ],
436
+ "Original Model": [
437
+ f"{original_result['latency']:.2f}",
438
+ f"{original_result['mean_confidence']:.4f}",
439
+ f"{size_original:.2f}",
440
+ "100.00",
441
+ ],
442
+ optimized_label: [
443
+ f"{optimized_result['latency']:.2f}",
444
+ f"{optimized_result['mean_confidence']:.4f}",
445
+ f"{size_optimized:.2f}",
446
+ f"{agreement:.2f}",
447
+ ],
448
+ }
449
+ )
450
+ return metrics_df
451
+
452
+
453
+ def build_class_distribution_df(
454
+ original_summary: list[dict[str, float]],
455
+ optimized_summary: list[dict[str, float]],
456
+ labels: list[str],
457
+ optimized_label: str,
458
+ max_rows: int = 25,
459
+ ) -> pd.DataFrame:
460
+ rows = []
461
+ for idx, label in enumerate(labels):
462
+ orig_entry = original_summary[idx]
463
+ opt_entry = optimized_summary[idx]
464
+ if orig_entry["count"] == 0 and opt_entry["count"] == 0:
465
+ continue
466
+ rows.append(
467
+ {
468
+ "Class": label,
469
+ "Original %": round(orig_entry["percent"], 2),
470
+ f"{optimized_label} %": round(opt_entry["percent"], 2),
471
+ "Original Pixels": orig_entry["count"],
472
+ f"{optimized_label} Pixels": opt_entry["count"],
473
+ }
474
+ )
475
+
476
+ rows.sort(key=lambda item: max(item["Original %"], item[f"{optimized_label} %"]), reverse=True)
477
+ if max_rows and len(rows) > max_rows:
478
+ rows = rows[:max_rows]
479
+
480
+ return pd.DataFrame(rows)
481
+
482
+
483
  # ---------------------------------------------
484
  # Image Preprocess
485
  # ---------------------------------------------
 
878
 
879
  print("=== RUN QUANTIZED COMPLETE ===")
880
  return metrics_df, chart_fig, downloads
881
+
882
+
883
+ def run_pruned_segmentation(
884
+ img,
885
+ model_choice,
886
+ method,
887
+ amount,
888
+ device_choice="auto",
889
+ channels_last=False,
890
+ use_compile=False,
891
+ use_amp=False,
892
+ export_ts=False,
893
+ export_onnx=False,
894
+ export_report=False,
895
+ export_state=True,
896
+ preset=None,
897
+ ):
898
+ print("\n=== RUN SEGMENTATION PRUNED CALLED ===")
899
+ if img is None:
900
+ print("ERROR: Image is None")
901
+ empty_metrics = pd.DataFrame({"Metric": ["Error"], "Original Model": ["No image"], "Pruned Model": [""]})
902
+ empty_dist = pd.DataFrame({"Class": [], "Original %": [], "Pruned %": []})
903
+ return empty_metrics, empty_dist, None, None, pd.DataFrame(), []
904
+
905
+ config = SEGMENTATION_MODEL_MAP.get(model_choice, SEGMENTATION_MODEL_CONFIGS[0])
906
+
907
+ if preset in PRESETS:
908
+ preset_cfg = PRESETS[preset]
909
+ device_choice = preset_cfg["device"]
910
+ channels_last = preset_cfg["channels_last"]
911
+ use_compile = preset_cfg["compile"]
912
+ use_amp = preset_cfg.get("amp", use_amp)
913
+ amount = preset_cfg.get("prune_amount", amount)
914
+
915
+ device = select_device(device_choice)
916
+
917
+ base_model = get_segmentation_model(config)
918
+ transform_fn = get_segmentation_transform(config)
919
+ class_labels = get_class_labels(config)
920
+ class_count = config.classes
921
+
922
+ original_result = run_segmentation_inference(
923
+ base_model,
924
+ img,
925
+ device,
926
+ transform_fn,
927
+ channels_last=channels_last,
928
+ warmup=True,
929
+ use_amp=use_amp,
930
+ class_count=class_count,
931
+ )
932
+
933
+ fresh_model = clone_segmentation_model(config)
934
+ pruned_model = apply_pruning(fresh_model, amount=float(amount), method=method)
935
+ pruned_model = maybe_compile(pruned_model, use_compile)
936
+ pruned_result = run_segmentation_inference(
937
+ pruned_model,
938
+ img,
939
+ device,
940
+ transform_fn,
941
+ channels_last=channels_last,
942
+ warmup=True,
943
+ use_amp=use_amp,
944
+ class_count=class_count,
945
+ )
946
+
947
+ size_orig = get_state_dict_size_mb(base_model)
948
+ size_pruned = get_state_dict_size_mb(pruned_model)
949
+ metrics_df = build_segmentation_metrics(original_result, pruned_result, size_orig, size_pruned, "Pruned Model")
950
+ class_df = build_class_distribution_df(
951
+ original_result["class_summary"],
952
+ pruned_result["class_summary"],
953
+ class_labels,
954
+ "Pruned",
955
+ )
956
+
957
+ # Add labels to images for slider comparison
958
+ overlay_orig_labeled = add_image_label(original_result["overlay_original"], "Original Model")
959
+ overlay_pruned_labeled = add_image_label(pruned_result["overlay_original"], "Pruned Model")
960
+ mask_orig_labeled = add_image_label(original_result["mask_image_original"], "Original Mask")
961
+ mask_pruned_labeled = add_image_label(pruned_result["mask_image_original"], "Pruned Mask")
962
+
963
+ overlay_slider_value = (
964
+ overlay_orig_labeled,
965
+ overlay_pruned_labeled,
966
+ )
967
+ mask_slider_value = (
968
+ mask_orig_labeled,
969
+ mask_pruned_labeled,
970
+ )
971
+ sparsity_df = compute_sparsity(pruned_model.cpu())
972
+
973
+ downloads: list[str] = []
974
+ export_dir = Path("exports")
975
+ export_dir.mkdir(exist_ok=True)
976
+
977
+ if export_report:
978
+ report_path = export_dir / "pruned_seg_report.json"
979
+ report = {
980
+ "model": config.name,
981
+ "checkpoint": config.checkpoint,
982
+ "dataset": config.dataset,
983
+ "pruning": {"method": method, "amount": float(amount)},
984
+ "metrics": metrics_df.to_dict(),
985
+ "class_distribution": class_df.to_dict(),
986
+ }
987
+ report_path.write_text(json.dumps(report, indent=2))
988
+ downloads.append(str(report_path))
989
+
990
+ if export_state:
991
+ state_path = export_dir / "pruned_seg_state_dict.pth"
992
+ torch.save(pruned_model.state_dict(), state_path)
993
+ downloads.append(str(state_path))
994
+
995
+ sample_tensor, _ = transform_fn(img)
996
+ sample_batch = sample_tensor.unsqueeze(0)
997
+
998
+ if export_ts:
999
+ ts_path = export_dir / "pruned_seg_model.ts"
1000
+ try:
1001
+ scripted = torch.jit.trace(pruned_model.cpu(), sample_batch)
1002
+ scripted.save(ts_path)
1003
+ downloads.append(str(ts_path))
1004
+ except Exception as exc: # pragma: no cover - export best effort
1005
+ print(f"TorchScript export failed: {exc}")
1006
+
1007
+ if export_onnx:
1008
+ onnx_path = export_dir / "pruned_seg_model.onnx"
1009
+ try:
1010
+ torch.onnx.export(
1011
+ pruned_model.cpu(),
1012
+ sample_batch,
1013
+ onnx_path,
1014
+ input_names=["input"],
1015
+ output_names=["mask"],
1016
+ opset_version=13,
1017
+ dynamic_axes={"input": {0: "batch"}, "mask": {0: "batch"}},
1018
+ )
1019
+ downloads.append(str(onnx_path))
1020
+ except Exception as exc: # pragma: no cover - export best effort
1021
+ print(f"ONNX export failed: {exc}")
1022
+
1023
+ return (
1024
+ metrics_df,
1025
+ class_df,
1026
+ overlay_slider_value,
1027
+ mask_slider_value,
1028
+ sparsity_df,
1029
+ downloads,
1030
+ )
1031
+
1032
+
1033
+ def run_quantized_segmentation(
1034
+ img,
1035
+ model_choice,
1036
+ q_type,
1037
+ device_choice="auto",
1038
+ channels_last=False,
1039
+ use_compile=False,
1040
+ use_amp=False,
1041
+ export_ts=False,
1042
+ export_onnx=False,
1043
+ export_report=False,
1044
+ export_state=True,
1045
+ preset=None,
1046
+ ):
1047
+ print("\n=== RUN SEGMENTATION QUANTIZED CALLED ===")
1048
+ if img is None:
1049
+ print("ERROR: Image is None")
1050
+ empty_metrics = pd.DataFrame({"Metric": ["Error"], "Original Model": ["No image"], "Quantized Model": [""]})
1051
+ empty_dist = pd.DataFrame({"Class": [], "Original %": [], "Quantized %": []})
1052
+ return empty_metrics, empty_dist, None, None, []
1053
+
1054
+ config = SEGMENTATION_MODEL_MAP.get(model_choice, SEGMENTATION_MODEL_CONFIGS[0])
1055
+
1056
+ if preset in PRESETS:
1057
+ preset_cfg = PRESETS[preset]
1058
+ device_choice = preset_cfg["device"]
1059
+ channels_last = preset_cfg["channels_last"]
1060
+ use_compile = preset_cfg["compile"]
1061
+ use_amp = preset_cfg.get("amp", use_amp)
1062
+ q_type = preset_cfg.get("quant", q_type)
1063
+
1064
+ device = select_device(device_choice)
1065
+ if q_type in {"dynamic", "weight_only"} and device.type != "cpu":
1066
+ print("Dynamic quantization runs on CPU; switching device to CPU.")
1067
+ device = torch.device("cpu")
1068
+ channels_last = False
1069
+ use_amp = False
1070
+
1071
+ base_model = get_segmentation_model(config)
1072
+ transform_fn = get_segmentation_transform(config)
1073
+ class_labels = get_class_labels(config)
1074
+ class_count = config.classes
1075
+
1076
+ original_result = run_segmentation_inference(
1077
+ base_model,
1078
+ img,
1079
+ device,
1080
+ transform_fn,
1081
+ channels_last=channels_last,
1082
+ warmup=True,
1083
+ use_amp=use_amp,
1084
+ class_count=class_count,
1085
+ )
1086
+
1087
+ fresh_model = clone_segmentation_model(config)
1088
+ quant_model = apply_quantization(fresh_model, q_type)
1089
+ quant_model = maybe_compile(quant_model, use_compile)
1090
+
1091
+ quant_result = run_segmentation_inference(
1092
+ quant_model,
1093
+ img,
1094
+ device,
1095
+ transform_fn,
1096
+ channels_last=channels_last,
1097
+ warmup=True,
1098
+ use_amp=use_amp,
1099
+ class_count=class_count,
1100
+ )
1101
+
1102
+ size_orig = get_state_dict_size_mb(base_model)
1103
+ size_quant = get_state_dict_size_mb(quant_model)
1104
+ metrics_df = build_segmentation_metrics(original_result, quant_result, size_orig, size_quant, "Quantized Model")
1105
+ class_df = build_class_distribution_df(
1106
+ original_result["class_summary"],
1107
+ quant_result["class_summary"],
1108
+ class_labels,
1109
+ "Quantized",
1110
+ )
1111
+
1112
+ # Add labels to images for slider comparison
1113
+ overlay_orig_labeled = add_image_label(original_result["overlay_original"], "Original Model")
1114
+ overlay_quant_labeled = add_image_label(quant_result["overlay_original"], "Quantized Model")
1115
+ mask_orig_labeled = add_image_label(original_result["mask_image_original"], "Original Mask")
1116
+ mask_quant_labeled = add_image_label(quant_result["mask_image_original"], "Quantized Mask")
1117
+
1118
+ overlay_slider_value = (
1119
+ overlay_orig_labeled,
1120
+ overlay_quant_labeled,
1121
+ )
1122
+ mask_slider_value = (
1123
+ mask_orig_labeled,
1124
+ mask_quant_labeled,
1125
+ )
1126
+
1127
+ downloads: list[str] = []
1128
+ export_dir = Path("exports")
1129
+ export_dir.mkdir(exist_ok=True)
1130
+
1131
+ if export_report:
1132
+ report_path = export_dir / "quant_seg_report.json"
1133
+ report = {
1134
+ "model": config.name,
1135
+ "checkpoint": config.checkpoint,
1136
+ "dataset": config.dataset,
1137
+ "quantization": q_type,
1138
+ "metrics": metrics_df.to_dict(),
1139
+ "class_distribution": class_df.to_dict(),
1140
+ }
1141
+ report_path.write_text(json.dumps(report, indent=2))
1142
+ downloads.append(str(report_path))
1143
+
1144
+ if export_state:
1145
+ state_path = export_dir / "quant_seg_state_dict.pth"
1146
+ torch.save(quant_model.state_dict(), state_path)
1147
+ downloads.append(str(state_path))
1148
+
1149
+ sample_tensor, _ = transform_fn(img)
1150
+ sample_batch = sample_tensor.unsqueeze(0)
1151
+
1152
+ if export_ts:
1153
+ ts_path = export_dir / "quant_seg_model.ts"
1154
+ try:
1155
+ scripted = torch.jit.trace(quant_model.cpu(), sample_batch)
1156
+ scripted.save(ts_path)
1157
+ downloads.append(str(ts_path))
1158
+ except Exception as exc: # pragma: no cover - export best effort
1159
+ print(f"TorchScript export failed: {exc}")
1160
+
1161
+ if export_onnx:
1162
+ onnx_path = export_dir / "quant_seg_model.onnx"
1163
+ try:
1164
+ torch.onnx.export(
1165
+ quant_model.cpu(),
1166
+ sample_batch,
1167
+ onnx_path,
1168
+ input_names=["input"],
1169
+ output_names=["mask"],
1170
+ opset_version=13,
1171
+ dynamic_axes={"input": {0: "batch"}, "mask": {0: "batch"}},
1172
+ )
1173
+ downloads.append(str(onnx_path))
1174
+ except Exception as exc: # pragma: no cover - export best effort
1175
+ print(f"ONNX export failed: {exc}")
1176
+
1177
+ return (
1178
+ metrics_df,
1179
+ class_df,
1180
+ overlay_slider_value,
1181
+ mask_slider_value,
1182
+ downloads,
1183
+ )
1184
  # ---------------------------------------------
1185
  # GRADIO UI
1186
  # ---------------------------------------------
1187
  examples = [["examples/cat.jpg"], ["examples/dog.jpg"], ["examples/bird.jpg"], ["examples/car.jpg"], ["examples/elephant.jpg"]]
1188
+ ade_examples = [["examples/ADE_val_00000001.jpg"], ["examples/ADE_val_00000002.jpg"], ["examples/ADE_val_00001001.jpg"], ["examples/ADE_val_00001842.jpg"]]
1189
 
1190
 
1191
  def create_demo():
 
1198
  if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
1199
  device_opts.append("mps")
1200
  preset_opts = list(PRESETS.keys()) + ["custom"]
1201
+ seg_model_options = [cfg.name for cfg in SEGMENTATION_MODEL_CONFIGS]
1202
 
1203
  with gr.Tabs():
1204
  # ---- PRUNING TAB ----
1205
+ with gr.Tab("Pruning-Classification"):
1206
  with gr.Row():
1207
  with gr.Column():
1208
  img_p = gr.Image(label="Upload Image")
 
1220
  btn_p = gr.Button("Run Pruned Model")
1221
  gr.Examples(examples=examples, inputs=img_p)
1222
  gr.Markdown(
1223
+ "### 📚 Classification Pruning Guide\n\n"
1224
+ "**What is Pruning?**\n"
1225
+ "Pruning removes less important weights from neural networks to reduce model size and potentially improve inference speed. "
1226
+ "This tab applies pruning to ImageNet classification models.\n\n"
1227
+ "**Options Explained:**\n"
1228
+ "- **Base Model**: Select from 7 pretrained architectures (ResNet-50, MobileNetV3, EfficientNet-B0, ConvNeXt-Tiny, ViT-Base, RegNetY-016, EfficientNet-Lite0). Each has different size/accuracy tradeoffs.\n"
1229
+ "- **Hardware Preset**: Quick configurations for common deployment scenarios:\n"
1230
+ " - *Edge CPU*: Optimized for resource-constrained devices (CPU-only, 30% pruning, dynamic quantization)\n"
1231
+ " - *Datacenter GPU*: Maximum performance on modern GPUs (CUDA, channels-last, compile, 20% pruning)\n"
1232
+ " - *Apple MPS*: Tuned for Apple Silicon (M1/M2/M3 chips with Metal Performance Shaders)\n"
1233
+ " - *Custom*: Manual control over all settings\n"
1234
+ "- **Pruning Method**:\n"
1235
+ " - *Structured*: Removes entire filters/channels; better hardware support and actual speedups\n"
1236
+ " - *Unstructured*: Zeros individual weights; higher compression but needs specialized sparse kernels for speedup\n"
1237
+ "- **Pruning Amount**: Percentage of weights to remove (0.1 = 10%, 0.9 = 90%). Higher values = smaller model but potential accuracy loss.\n"
1238
+ "- **Device**: Inference hardware (auto-detects best available: CUDA → MPS → CPU)\n"
1239
+ "- **Channels-last (CUDA only)**: Memory layout optimization for faster convolution operations on NVIDIA GPUs\n"
1240
+ "- **Mixed Precision (AMP)**: Uses FP16 where safe, FP32 where needed; faster on modern GPUs with Tensor Cores\n"
1241
+ "- **Torch Compile**: PyTorch 2.0+ graph optimization; can provide 20-40% speedup but adds compilation overhead\n\n"
1242
+ "**Export Options:**\n"
1243
+ "- *TorchScript*: Serialized model for C++ deployment or production serving\n"
1244
+ "- *ONNX*: Cross-framework format (TensorRT, OpenVINO, ONNX Runtime, CoreML)\n"
1245
+ "- *JSON Report*: Detailed metrics, settings, and Top-5 predictions for both models\n"
1246
+ "- *State Dict*: Always saved; PyTorch checkpoint for loading pruned weights later\n\n"
1247
+ "**Reading the Results:**\n"
1248
+ "- *Comparison Metrics*: Side-by-side accuracy, speed, and size\n"
1249
+ "- *Top-5 Chart*: Visual comparison of prediction confidence across models\n"
1250
+ "- *Layer Sparsity*: Per-layer breakdown showing which parts were pruned most"
1251
  )
1252
 
1253
  with gr.Column():
 
1277
  )
1278
 
1279
  # ---- QUANTIZATION TAB ----
1280
+ with gr.Tab("Quantization-Classification"):
1281
  with gr.Row():
1282
  with gr.Column():
1283
  img_q = gr.Image(label="Upload Image")
 
1294
  btn_q = gr.Button("Run Quantized Model")
1295
  gr.Examples(examples=examples, inputs=img_q)
1296
  gr.Markdown(
1297
+ "### 📚 Classification Quantization Guide\n\n"
1298
+ "**What is Quantization?**\n"
1299
+ "Quantization reduces model precision from 32-bit floats to lower bit-widths (INT8, FP16), decreasing memory usage and "
1300
+ "enabling faster inference on hardware with specialized low-precision instructions.\n\n"
1301
+ "**Options Explained:**\n"
1302
+ "- **Base Model**: Choose from 7 pretrained ImageNet classifiers with varying complexity.\n"
1303
+ "- **Hardware Preset**: Same presets as pruning tab, but with quantization-specific defaults.\n"
1304
+ "- **Quantization Type**:\n"
1305
+ " - *Dynamic*: Post-training INT8 quantization on linear layers; activations quantized dynamically at runtime. **Forces CPU** (PyTorch's INT8 kernels are CPU-only). Best for transformers and MLP-heavy models.\n"
1306
+ " - *Weight-only*: Stores weights as INT8, computes in FP32. Reduces memory bandwidth, smaller model files. **CPU-optimized**.\n"
1307
+ " - *FP16*: Half-precision floating point; requires GPU with FP16 support (CUDA, MPS). Minimal accuracy loss, ~2x speedup on modern GPUs.\n"
1308
+ "- **Device**: Hardware target (dynamic/weight-only auto-switch to CPU for kernel compatibility)\n"
1309
+ "- **Channels-last**: CUDA memory layout optimization (ignored on CPU)\n"
1310
+ "- **Mixed Precision (AMP)**: Can combine with FP16 quantization on GPUs\n"
1311
+ "- **Torch Compile**: Graph-level optimizations from PyTorch 2.0+\n\n"
1312
+ "**Export Options:** Same as pruning (TorchScript, ONNX, JSON report, state dict)\n\n"
1313
+ "**Important Notes:**\n"
1314
+ "⚠️ Dynamic/weight-only quantization automatically uses CPU even if GPU is selected (PyTorch limitation)\n"
1315
+ "⚠️ ResNet-50 and similar CNN-heavy models see modest INT8 speedups because only linear layers are quantized\n"
1316
+ "⚠️ FP16 on CPU often reverts to FP32 internally, adding overhead instead of speedup\n\n"
1317
+ "**Reading the Results:**\n"
1318
+ "- *Latency*: Dynamic quantization may show higher latency due to runtime overhead; production deployments should use cached models\n"
1319
+ "- *Model Size*: FP16 ≈ 50% reduction, INT8 dynamic ≈ 75% reduction (varies by architecture)\n"
1320
+ "- *Accuracy*: Watch for confidence drops; quantization can shift predictions slightly"
1321
  )
1322
 
1323
 
 
1345
  outputs=[metrics_q, chart_q, downloads_q],
1346
  )
1347
 
1348
+ # ---- SEGMENTATION PRUNING TAB ----
1349
+ with gr.Tab("Pruning-Segmentation"):
1350
+ with gr.Row():
1351
+ with gr.Column():
1352
+ img_sp = gr.Image(label="Upload Image")
1353
+ model_sp = gr.Dropdown(seg_model_options, value=seg_model_options[0], label="Pretrained ADE20K Model")
1354
+ preset_sp = gr.Dropdown(preset_opts, value="custom", label="Hardware Preset")
1355
+ method_sp = gr.Dropdown(["unstructured", "structured"], value="structured", label="Pruning Method")
1356
+ amount_sp = gr.Slider(minimum=0.1, maximum=0.9, step=0.1, value=0.4, label="Pruning Amount")
1357
+ device_sp = gr.Dropdown(device_opts, value=device_opts[0], label="Device")
1358
+ channels_last_sp = gr.Checkbox(label="Channels-last input (CUDA)", value=True)
1359
+ compile_sp = gr.Checkbox(label="Torch compile (PyTorch 2)")
1360
+ amp_sp = gr.Checkbox(label="Mixed precision (AMP)", value=True)
1361
+ export_ts_sp = gr.Checkbox(label="Export TorchScript")
1362
+ export_onnx_sp = gr.Checkbox(label="Export ONNX")
1363
+ export_report_sp = gr.Checkbox(label="Export JSON report", value=True)
1364
+ btn_sp = gr.Button("Run Segmentation Pruning")
1365
+ gr.Examples(examples=ade_examples, inputs=img_sp, label="ADE20K Samples")
1366
+ gr.Markdown(
1367
+ "### 🎨 Segmentation Pruning Guide\n\n"
1368
+ "**What is Semantic Segmentation?**\n"
1369
+ "Semantic segmentation assigns a class label to every pixel in an image (e.g., sky, road, person, car). "
1370
+ "This tab uses ADE20K-pretrained models that recognize 150 scene categories.\n\n"
1371
+ "**Available Models:**\n"
1372
+ "- **SegFormer B0** (512x512): Lightweight transformer-based segmenter; efficient for edge deployment\n"
1373
+ "- **SegFormer B4** (512x512): Larger variant with better accuracy; ~4x B0 parameters\n"
1374
+ "- **DPT Large**: Vision-transformer-based dense prediction; state-of-the-art accuracy but slower\n"
1375
+ "- **UPerNet ConvNeXt-Tiny**: Unified perceptual parsing with modern CNN backbone; balanced speed/accuracy\n\n"
1376
+ "**Segmentation-Specific Options:**\n"
1377
+ "- All pruning/device/compile options work the same as classification\n"
1378
+ "- Models use [smp-hub](https://huggingface.co/smp-hub) pretrained checkpoints via `segmentation-models-pytorch`\n"
1379
+ "- Preprocessing pipelines are model-specific (loaded from Hugging Face metadata)\n"
1380
+ "- Images are resized based on model training resolution (usually 512x512 or 640x640)\n\n"
1381
+ "**Understanding Segmentation Outputs:**\n"
1382
+ "1. **Comparison Metrics Table**:\n"
1383
+ " - *Latency*: Inference time for full-image segmentation\n"
1384
+ " - *Mean Confidence*: Average softmax probability across all pixels\n"
1385
+ " - *Model Size*: State dict size in MB\n"
1386
+ " - *Mask Agreement*: % of pixels with identical class predictions (100% = perfect match)\n"
1387
+ "2. **Class Distribution Table**:\n"
1388
+ " - Top 25 most prevalent classes by pixel coverage\n"
1389
+ " - Shows percentage and pixel counts for both models\n"
1390
+ " - Helps identify which objects dominate the scene\n"
1391
+ "3. **Overlay Comparison Slider**:\n"
1392
+ " - Original image blended with colored segmentation masks\n"
1393
+ " - Drag slider to compare original vs. pruned predictions\n"
1394
+ " - Colors map to specific ADE20K classes (150 categories)\n"
1395
+ "4. **Mask Comparison Slider**:\n"
1396
+ " - Raw segmentation masks without image overlay\n"
1397
+ " - Easier to spot subtle prediction differences\n"
1398
+ "5. **Layer Sparsity Table**:\n"
1399
+ " - Per-layer pruning statistics showing compression levels\n\n"
1400
+ "**Export Options:**\n"
1401
+ "Files saved with `_seg` suffix: `pruned_seg_model.ts`, `pruned_seg_report.json`, etc.\n\n"
1402
+ "**Tips:**\n"
1403
+ "- Use ADE20K validation images (provided examples) for meaningful class diversity\n"
1404
+ "- High mask agreement (>95%) indicates pruning preserved segmentation quality\n"
1405
+ "- Check class distribution to ensure dominant objects aren't misclassified\n"
1406
+ "- Structured pruning typically maintains better segmentation quality than unstructured"
1407
+ )
1408
+
1409
+ with gr.Column():
1410
+ metrics_sp = gr.Dataframe(label="📊 Comparison Metrics")
1411
+ class_sp = gr.Dataframe(label="📈 Class Distribution")
1412
+ overlay_slider_sp = gr.ImageSlider(label="Overlay Comparison", type="pil")
1413
+ mask_slider_sp = gr.ImageSlider(label="Mask Comparison", type="pil")
1414
+ sparsity_sp = gr.Dataframe(label="Layer sparsity (%)")
1415
+ downloads_sp = gr.Files(label="Exports (state_dict / TorchScript / ONNX / report)")
1416
+
1417
+ btn_sp.click(
1418
+ fn=run_pruned_segmentation,
1419
+ inputs=[
1420
+ img_sp,
1421
+ model_sp,
1422
+ method_sp,
1423
+ amount_sp,
1424
+ device_sp,
1425
+ channels_last_sp,
1426
+ compile_sp,
1427
+ amp_sp,
1428
+ export_ts_sp,
1429
+ export_onnx_sp,
1430
+ export_report_sp,
1431
+ gr.State(True),
1432
+ preset_sp,
1433
+ ],
1434
+ outputs=[
1435
+ metrics_sp,
1436
+ class_sp,
1437
+ overlay_slider_sp,
1438
+ mask_slider_sp,
1439
+ sparsity_sp,
1440
+ downloads_sp,
1441
+ ],
1442
+ )
1443
+
1444
+ # ---- SEGMENTATION QUANTIZATION TAB ----
1445
+ with gr.Tab("Quantization-Segmentation"):
1446
+ with gr.Row():
1447
+ with gr.Column():
1448
+ img_sq = gr.Image(label="Upload Image")
1449
+ model_sq = gr.Dropdown(seg_model_options, value=seg_model_options[0], label="Pretrained ADE20K Model")
1450
+ preset_sq = gr.Dropdown(preset_opts, value="custom", label="Hardware Preset")
1451
+ q_type_sq = gr.Dropdown(["dynamic", "weight_only", "fp16"], value="dynamic", label="Quantization Type")
1452
+ device_sq = gr.Dropdown(device_opts, value=device_opts[0], label="Device")
1453
+ channels_last_sq = gr.Checkbox(label="Channels-last input (CUDA)", value=True)
1454
+ compile_sq = gr.Checkbox(label="Torch compile (PyTorch 2)")
1455
+ amp_sq = gr.Checkbox(label="Mixed precision (AMP)", value=True)
1456
+ export_ts_sq = gr.Checkbox(label="Export TorchScript")
1457
+ export_onnx_sq = gr.Checkbox(label="Export ONNX")
1458
+ export_report_sq = gr.Checkbox(label="Export JSON report", value=True)
1459
+ btn_sq = gr.Button("Run Segmentation Quantization")
1460
+ gr.Examples(examples=ade_examples, inputs=img_sq, label="ADE20K Samples")
1461
+ gr.Markdown(
1462
+ "### 🎨 Segmentation Quantization Guide\n\n"
1463
+ "**Quantization for Dense Prediction:**\n"
1464
+ "Semantic segmentation models are typically larger and slower than classifiers, making quantization especially valuable. "
1465
+ "This tab applies the same quantization techniques as classification but evaluates pixel-level prediction quality.\n\n"
1466
+ "**Available Models & Quantization:**\n"
1467
+ "- **SegFormer B0/B4**: Transformer-based; dynamic quantization helps with attention/MLP layers (CPU-only)\n"
1468
+ "- **DPT Large**: Vision-transformer backbone; benefits significantly from FP16 on GPU (~2x speedup)\n"
1469
+ "- **UPerNet ConvNeXt-Tiny**: CNN-based; FP16 quantization provides best GPU acceleration\n\n"
1470
+ "**Quantization Type Selection:**\n"
1471
+ "- **Dynamic/Weight-only**: ⚠️ Automatically uses CPU (PyTorch INT8 limitation). Best for: \n"
1472
+ " - Transformer-heavy models (SegFormer, DPT)\n"
1473
+ " - CPU-only deployment scenarios\n"
1474
+ " - Memory-constrained environments\n"
1475
+ "- **FP16**: Recommended for GPU deployment (CUDA, MPS). Provides:\n"
1476
+ " - ~2x inference speedup on modern GPUs\n"
1477
+ " - 50% memory reduction\n"
1478
+ " - Minimal segmentation quality loss (<1% mIoU typically)\n\n"
1479
+ "**Segmentation-Specific Metrics:**\n"
1480
+ "1. **Mask Agreement**: Critical metric for segmentation; >95% is good, >98% is excellent\n"
1481
+ "2. **Mean Confidence**: Should remain similar; large drops indicate quantization instability\n"
1482
+ "3. **Class Distribution**: Compare pixel percentages; mismatches show which objects are affected\n\n"
1483
+ "**Understanding the Outputs:**\n"
1484
+ "- **Overlay Slider**: Drag to compare original vs. quantized predictions on the actual image\n"
1485
+ "- **Mask Slider**: Raw segmentation masks for detailed comparison\n"
1486
+ "- **Class Distribution**: Top 25 classes help identify systematic errors (e.g., 'road' β†’ 'sidewalk' confusion)\n\n"
1487
+ "**Performance Expectations:**\n"
1488
+ "- **FP16 on CUDA**: Expect 1.5-2x speedup with <1% accuracy loss\n"
1489
+ "- **Dynamic on CPU**: Model size ↓ 75%, latency may increase (first-run overhead)\n"
1490
+ "- **Weight-only on CPU**: Model size ↓ 50%, latency similar to FP32\n\n"
1491
+ "**Export Options:**\n"
1492
+ "Files saved with `_seg` suffix: `quant_seg_model.onnx`, `quant_seg_state_dict.pth`, etc.\n\n"
1493
+ "**Best Practices:**\n"
1494
+ "βœ“ Use FP16 for GPU deployment (CUDA, MPS)\n"
1495
+ "βœ“ Use dynamic quantization for CPU-bound transformer models\n"
1496
+ "βœ“ Check mask agreement before deploying; <90% needs investigation\n"
1497
+ "βœ“ Validate on multiple images; some scenes may be more sensitive to quantization\n"
1498
+ "βœ— Avoid FP16 on CPU (performance penalty, not benefit)\n"
1499
+ "βœ— Don't expect large speedups from dynamic quantization on CNN-heavy models (most layers are Conv2d, not Linear)"
1500
+ )
1501
+
1502
+ with gr.Column():
1503
+ metrics_sq = gr.Dataframe(label="📊 Comparison Metrics")
1504
+ class_sq = gr.Dataframe(label="📈 Class Distribution")
1505
+ overlay_slider_sq = gr.ImageSlider(label="Overlay Comparison", type="pil")
1506
+ mask_slider_sq = gr.ImageSlider(label="Mask Comparison", type="pil")
1507
+ downloads_sq = gr.Files(label="Exports (state_dict / TorchScript / ONNX / report)")
1508
+
1509
+ btn_sq.click(
1510
+ fn=run_quantized_segmentation,
1511
+ inputs=[
1512
+ img_sq,
1513
+ model_sq,
1514
+ q_type_sq,
1515
+ device_sq,
1516
+ channels_last_sq,
1517
+ compile_sq,
1518
+ amp_sq,
1519
+ export_ts_sq,
1520
+ export_onnx_sq,
1521
+ export_report_sq,
1522
+ gr.State(True),
1523
+ preset_sq,
1524
+ ],
1525
+ outputs=[
1526
+ metrics_sq,
1527
+ class_sq,
1528
+ overlay_slider_sq,
1529
+ mask_slider_sq,
1530
+ downloads_sq,
1531
+ ],
1532
+ )
1533
+
1534
  return demo
1535
 
1536
 
examples/ADE_val_00000001.jpg ADDED
examples/ADE_val_00000002.jpg ADDED
requirements.txt CHANGED
@@ -2,6 +2,9 @@
2
  torch>=2.2.0
3
  torchvision>=0.17.0
4
  timm>=0.9.12
5
 
6
  # UI
7
  gradio>=4.19.2
 
2
  torch>=2.2.0
3
  torchvision>=0.17.0
4
  timm>=0.9.12
5
+ segmentation-models-pytorch>=0.3.3
6
+ huggingface-hub>=0.23.0
7
+ albumentations>=1.4.8
8
 
9
  # UI
10
  gradio>=4.19.2