shriarul5273 committed on
Commit 2797bac · 1 Parent(s): 5d4f6a9

add YOLO12 models and update README, app.py, and requirements.txt

Files changed (4)
  1. .gitignore +2 -0
  2. README.md +16 -5
  3. app.py +784 -0
  4. requirements.txt +1 -0
.gitignore CHANGED
@@ -2,3 +2,5 @@
 exports/
 __pycache__/
 *.pyc
+*.pth
+*.pt
README.md CHANGED
@@ -11,16 +11,18 @@ pinned: false
 
 # Model Optimization Lab
 
-Interactive Gradio playground for comparing pruning and quantization on both ImageNet-classification and ADE20K-segmentation models. Upload any image and observe how latency, confidence, model size, and segmentation quality change when applying different compression recipes. Pretrained weights are loaded by default; set `MODEL_OPT_PRETRAINED=0` if you want random initialization for experimentation.
+Interactive Gradio playground for comparing pruning and quantization on ImageNet classification, ADE20K segmentation, and COCO detection models (TorchVision + YOLO12). Upload any image and observe how latency, confidence, model size, and segmentation/detection quality change when applying different compression recipes. Pretrained weights are loaded by default; set `MODEL_OPT_PRETRAINED=0` if you want random initialization for experimentation.
 
 ## Features
 - **Classification Tasks**: Baseline FP32 inference using cached backbones (ResNet-50, MobileNetV3, EfficientNet-B0, ConvNeXt-Tiny, ViT-B/16, RegNetY-016, EfficientNet-Lite0).
 - **Segmentation Tasks**: Pretrained ADE20K models (SegFormer B0/B4, DPT Large, UPerNet ConvNeXt-Tiny) with 150-class semantic segmentation.
-- **Pruning tabs**: Structured/unstructured pruning with configurable sparsity and comprehensive size/latency comparison for both classification and segmentation.
-- **Quantization tabs**: Dynamic, weight-only INT8, and FP16 passes with CPU-safe fallbacks for unsupported kernels, available for both task types.
+- **Detection Tasks**: COCO-pretrained detectors (TorchVision Faster R-CNN/SSDlite) plus Ultralytics YOLO12 n/s/m/l/x.
+- **Pruning tabs**: Structured/unstructured pruning with configurable sparsity and comprehensive size/latency comparison across tasks.
+- **Quantization tabs**: Dynamic, weight-only INT8, and FP16 passes with CPU-safe fallbacks for unsupported kernels, available for all tasks.
 - **Visual Comparisons**:
   - Classification: Automated metric tables and Top-5 bar charts to visualize confidence shifts.
   - Segmentation: Image sliders for overlay/mask comparisons, class distribution tables, and mask agreement metrics.
+  - Detection: Overlay sliders for pruned/quantized boxes and detection tables for quick inspection.
 - **Export Options**: TorchScript, ONNX, JSON reports, and state dictionaries for all optimization variants.
 - Lightweight CLI mode for quick experiments without launching the UI.
 
@@ -53,15 +55,16 @@ Interactive Gradio playground for comparing pruning and quantization on both Ima
 5. Open the local Gradio URL (printed in the terminal) in your browser.
 
 ## Using the App
-1. **Upload an image** or pick one of the provided examples (ImageNet samples for classification, ADE20K validation images for segmentation).
+1. **Upload an image** or pick one of the provided examples (ImageNet samples for classification, ADE20K validation images for segmentation; detection works with any RGB image).
 2. Choose the **Base Model** dropdown:
    - **Classification**: ResNet-50, MobileNetV3-Large, EfficientNet-B0, ConvNeXt-Tiny, ViT-B/16, RegNetY-016, EfficientNet-Lite0
    - **Segmentation**: SegFormer B0/B4 (ADE20K 512x512), DPT Large (ADE20K), UPerNet ConvNeXt-Tiny (ADE20K)
+   - **Detection**: Faster R-CNN ResNet50 FPN (COCO), SSDlite320 MobileNetV3 (COCO), YOLO12 n/s/m/l/x (COCO via Ultralytics)
 3. Pick a **Hardware Preset** or keep `custom`:
    - Edge CPU — CPU, channels-last off, dynamic quantization, 30% pruning.
    - Datacenter GPU — CUDA, channels-last on, `torch.compile`, FP16 quantization, 20% pruning.
   - Apple MPS — MPS, FP16 quantization, 20% pruning.
-4. Select a tab (Pruning-Classification, Quantization-Classification, Pruning-Segmentation, or Quantization-Segmentation), configure options, then click **Run**.
+4. Select a tab (Pruning/Quantization for Classification, Detection, or Segmentation), configure options, then click **Run**.
 
 ### Pruning tab options (Classification & Segmentation)
 - `Pruning Method`: `structured` (LN-structured) or `unstructured` (L1). Applied to Conv2d weights before export.
@@ -77,6 +80,13 @@ Interactive Gradio playground for comparing pruning and quantization on both Ima
 - Classification: Comparison metrics, Top-5 bar chart, per-layer sparsity table, download list
 - Segmentation: Comparison metrics, class distribution table, overlay/mask sliders, per-layer sparsity table, download list
 
+### Detection tab options (Pruning & Quantization)
+- `Models`: TorchVision Faster R-CNN / SSDlite, plus Ultralytics YOLO12 n/s/m/l/x (checkpoints auto-download if missing).
+- `Score Threshold`: Filters low-confidence boxes before metrics/overlays.
+- `Pruning`: Structured is recommended for detection heads; unstructured yields higher sparsity but fewer real speedups.
+- `Quantization`: Dynamic/weight-only INT8 forces CPU for kernel support; FP16 targets CUDA/MPS. AMP + channels-last help on GPU.
+- `Exports`: State dicts are always saved. TorchScript/ONNX exports remain enabled for TorchVision detectors; for YOLO12 they are skipped, but the state dict is still written.
+
 ### Quantization tab options (Classification & Segmentation)
 - `Quantization Type`: `dynamic`/`weight_only` (INT8 linear layers on CPU), or `fp16` (casts model to half precision).
 - `Device`: `auto` picks CUDA → MPS → CPU; dynamic/weight-only runs force CPU execution for kernel support.
@@ -108,6 +118,7 @@ Interactive Gradio playground for comparing pruning and quantization on both Ima
 - Dynamic and weight-only quantization only affect linear layers; ResNet-50 is dominated by convolution blocks that remain FP32, so speedups are modest on CPU. Unsupported static INT8 kernels automatically fall back to dynamic quantization.
 - PyTorch default quantization backend may fall back to `qnnpack` on CPU. For x86 systems, set `torch.backends.quantized.engine = "fbgemm"` before quantization for best results.
 - FP16 inference is beneficial on GPUs. On CPU, PyTorch often casts half tensors back to float32, introducing overhead.
+- Detection-specific: Dynamic/weight-only runs force CPU for kernel support; YOLO12 checkpoints auto-download, but TorchScript/ONNX exports are disabled (state dicts are still saved).
 
 ## Extending the Lab
 - **Classification**: Swap in different architectures by changing the `timm.create_model` call in `app.py`.
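
For readers reproducing the quantization notes above, here is a minimal, self-contained sketch (not part of this commit) of the CPU INT8 recipe they describe: pick the `fbgemm` engine on x86, then dynamically quantize the `Linear` layers.

```python
import torch
import torch.nn as nn

# On x86, prefer fbgemm before quantizing (the default may fall back to qnnpack).
torch.backends.quantized.engine = "fbgemm"

model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10)).eval()

# Dynamic quantization rewrites only nn.Linear modules to INT8 kernels,
# which is why conv-heavy models such as ResNet-50 see modest CPU gains.
quantized = torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

with torch.no_grad():
    out = quantized(torch.randn(1, 128))
print(out.shape)  # torch.Size([1, 10])
```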
app.py CHANGED
@@ -1,4 +1,5 @@
 import argparse
+import copy
 import io
 import json
 import os
@@ -17,6 +18,16 @@ import torch.nn.utils.prune as prune
 import segmentation_models_pytorch as smp
 from PIL import Image, ImageDraw, ImageFont
 from torchvision import transforms
+from torchvision.models.detection import (
+    FasterRCNN_ResNet50_FPN_Weights,
+    SSDLite320_MobileNet_V3_Large_Weights,
+    fasterrcnn_resnet50_fpn,
+    ssdlite320_mobilenet_v3_large,
+)
+
+try:
+    from ultralytics import YOLO as UltralyticsYOLO
+except ModuleNotFoundError:  # pragma: no cover - optional dependency
+    UltralyticsYOLO = None
 try:
     import albumentations as A
 except ModuleNotFoundError:  # pragma: no cover - optional dependency
@@ -131,6 +142,46 @@ ADE20K_CLASS_NAMES = [
 ]
 
 
+# ---------------------------------------------
+# Object Detection Registry / Defaults
+# ---------------------------------------------
+DETECTION_MODEL_CONFIGS = {
+    "Faster R-CNN ResNet50 FPN (COCO)": {
+        "builder": fasterrcnn_resnet50_fpn,
+        "weights": FasterRCNN_ResNet50_FPN_Weights.DEFAULT,
+        "backend": "torchvision",
+    },
+    "SSDlite320 MobileNetV3 (COCO)": {
+        "builder": ssdlite320_mobilenet_v3_large,
+        "weights": SSDLite320_MobileNet_V3_Large_Weights.DEFAULT,
+        "backend": "torchvision",
+    },
+}
+COCO_CATEGORIES = list(FasterRCNN_ResNet50_FPN_Weights.DEFAULT.meta.get("categories", []))
+
+# Optional YOLO12 variants (Ultralytics) with size options: n/s/m/l/x
+for _size in ("n", "s", "m", "l", "x"):
+    DETECTION_MODEL_CONFIGS[f"YOLO12-{_size} (COCO)"] = {
+        "backend": "ultralytics",
+        "weights": f"yolo12{_size}.pt",
+        "imgsz": 640,
+        "categories": COCO_CATEGORIES,
+    }
+DETECTION_MODEL_OPTIONS = list(DETECTION_MODEL_CONFIGS.keys())
+
+_DET_MODEL_CACHE: dict[str, nn.Module] = {}
+_DET_TRANSFORM_CACHE: dict[str, object] = {}
+_DET_LABELS_CACHE: dict[str, list[str]] = {}
+
+
+def _require_ultralytics():
+    if UltralyticsYOLO is None:
+        raise RuntimeError(
+            "The 'ultralytics' package is required for YOLO12 models. "
+            "Install it with `pip install ultralytics` to enable these options."
+        )
+
+
 def add_image_label(img: Image.Image, label: str) -> Image.Image:
     """Add a text label at the top of an image."""
     img_array = np.array(img)
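
The loop above registers five YOLO12 entries alongside the two TorchVision detectors; as a quick illustrative check (hypothetical REPL output, assuming app.py has been imported), these names are what the detector dropdowns will show:

```python
print(DETECTION_MODEL_OPTIONS)
# ['Faster R-CNN ResNet50 FPN (COCO)', 'SSDlite320 MobileNetV3 (COCO)',
#  'YOLO12-n (COCO)', 'YOLO12-s (COCO)', 'YOLO12-m (COCO)',
#  'YOLO12-l (COCO)', 'YOLO12-x (COCO)']
print(DETECTION_MODEL_CONFIGS["YOLO12-n (COCO)"]["weights"])  # yolo12n.pt
```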
 
@@ -354,6 +405,304 @@ def get_class_labels(config: SegmentationModelConfig) -> list[str]:
     return labels[: config.classes]
 
 
+# ---------------------------------------------
+# Object Detection Utilities
+# ---------------------------------------------
+def get_detection_config(model_name: str) -> dict:
+    if model_name not in DETECTION_MODEL_CONFIGS:
+        raise ValueError(f"Unknown detection model: {model_name}")
+    return dict(DETECTION_MODEL_CONFIGS[model_name])
+
+
+def get_detection_labels(model_name: str) -> list[str]:
+    if model_name in _DET_LABELS_CACHE:
+        return _DET_LABELS_CACHE[model_name]
+    cfg = get_detection_config(model_name)
+    categories = cfg.get("categories")
+    if categories:
+        labels = categories
+    else:
+        weights = cfg.get("weights")
+        labels = weights.meta.get("categories", []) if weights else []
+    _DET_LABELS_CACHE[model_name] = list(labels)
+    return _DET_LABELS_CACHE[model_name]
+
+
+def get_detection_transform(model_name: str):
+    if model_name in _DET_TRANSFORM_CACHE:
+        return _DET_TRANSFORM_CACHE[model_name]
+    cfg = get_detection_config(model_name)
+    backend = cfg.get("backend", "torchvision")
+    if backend == "ultralytics":
+        transform = lambda img: img  # Ultralytics handles preprocessing internally
+    else:
+        weights = cfg.get("weights")
+        transform = weights.transforms() if weights else transforms.Compose([transforms.ToTensor()])
+    _DET_TRANSFORM_CACHE[model_name] = transform
+    return transform
+
+
+def get_detection_model(model_name: str) -> nn.Module:
+    if model_name not in _DET_MODEL_CACHE:
+        cfg = get_detection_config(model_name)
+        backend = cfg.get("backend", "torchvision")
+        if backend == "ultralytics":
+            _require_ultralytics()
+            weights = cfg.get("weights")
+            try:
+                model = UltralyticsYOLO(weights)
+            except Exception as exc:
+                raise RuntimeError(
+                    f"Failed to load YOLO12 weights '{weights}'. Download or place the checkpoint locally first."
+                ) from exc
+            if hasattr(model, "model"):
+                model.model.eval()
+        else:
+            weights = cfg.get("weights")
+            try:
+                model = cfg["builder"](weights=weights)
+            except Exception as exc:
+                print(f"Warning: detection weights unavailable ({exc}); using random init for {model_name}")
+                model = cfg["builder"](weights=None)
+            model.eval()
+        _DET_MODEL_CACHE[model_name] = model
+    return _DET_MODEL_CACHE[model_name]
+
+
+def clone_detection_model(model_name: str) -> nn.Module:
+    base = get_detection_model(model_name)
+    cfg = get_detection_config(model_name)
+    backend = cfg.get("backend", "torchvision")
+    if backend == "ultralytics":
+        _require_ultralytics()
+        fresh = copy.deepcopy(base)
+        if hasattr(fresh, "model") and isinstance(fresh.model, nn.Module):
+            fresh.model.eval()
+        return fresh
+
+    fresh = cfg["builder"](weights=None)
+    fresh.load_state_dict(base.state_dict())
+    fresh.eval()
+    return fresh
+
+
+def prepare_detection_input(image, transform_fn):
+    if image is None:
+        raise ValueError("No image provided")
+    if not isinstance(image, Image.Image):
+        if isinstance(image, np.ndarray) and image.dtype != np.uint8:
+            image = (np.clip(image, 0, 1) * 255).astype(np.uint8)
+        image = Image.fromarray(np.array(image).astype("uint8"))
+    image_rgb = image.convert("RGB")
+    tensor = transform_fn(image_rgb)
+    if not torch.is_tensor(tensor):  # coerce transforms that return arrays
+        tensor = torch.as_tensor(tensor)
+    return tensor, image_rgb
+
+
+def draw_detections(image: Image.Image, detections: list[dict], max_dets: int = 30) -> Image.Image:
+    canvas = image.copy()
+    draw = ImageDraw.Draw(canvas)
+    colors = _SEG_BASE_PALETTE  # reuse the segmentation palette for variety
+    for idx, det in enumerate(detections[:max_dets]):
+        box = det["box"]
+        color = tuple(int(c) for c in colors[idx % len(colors)])
+        draw.rectangle(box, outline=color, width=3)
+        label = f"{det['label']} {det['score']:.2f}"
+        draw.text((box[0] + 4, box[1] + 4), label, fill=color)
+    return canvas
+
+
+def run_detection_inference(
+    model: nn.Module,
+    image,
+    device: torch.device,
+    transform_fn,
+    channels_last: bool,
+    warmup: bool,
+    use_amp: bool,
+    score_thresh: float = 0.25,
+    backend: str = "torchvision",
+    imgsz: int | None = None,
+):
+    if backend == "ultralytics":
+        if image is None:
+            raise ValueError("No image provided")
+        if not isinstance(image, Image.Image):
+            if isinstance(image, np.ndarray) and image.dtype != np.uint8:
+                image = (np.clip(image, 0, 1) * 255).astype(np.uint8)
+            image = Image.fromarray(np.array(image).astype("uint8"))
+        image_rgb = image.convert("RGB")
+
+        device_arg = str(device) if isinstance(device, torch.device) else device
+        half = use_amp and isinstance(device, torch.device) and device.type == "cuda"
+        if hasattr(model, "model") and isinstance(model.model, nn.Module):
+            model.model.to(device)
+
+        if warmup:
+            with torch.no_grad():
+                model.predict(image_rgb, imgsz=imgsz, device=device_arg, verbose=False, half=half)
+
+        start = time.time()
+        with torch.no_grad():
+            results = model.predict(image_rgb, imgsz=imgsz, device=device_arg, verbose=False, half=half)
+        latency = (time.time() - start) * 1000
+
+        dets: list[dict] = []
+        if results:
+            res = results[0]
+            boxes = getattr(res, "boxes", None)
+            if boxes is not None:
+                xyxy = boxes.xyxy.detach().cpu().numpy()
+                confs = boxes.conf.detach().cpu().numpy()
+                labels = boxes.cls.detach().cpu().numpy()
+                for box, score, label_idx in zip(xyxy, confs, labels):
+                    if score < score_thresh:
+                        continue
+                    dets.append(
+                        {
+                            "label": str(int(label_idx)),
+                            "score": float(score),
+                            "box": [float(x) for x in box],
+                        }
+                    )
+
+        return {"detections": dets, "latency": latency, "image": image_rgb}
+
+    tensor, image_rgb = prepare_detection_input(image, transform_fn)
+    model = model.to(device)
+
+    batch_tensor = tensor.to(device)
+    # Channels-last requires NCHW (4D) input; detection tensors are usually 3D.
+    if channels_last and device.type == "cuda" and batch_tensor.dim() == 4:
+        batch_tensor = batch_tensor.to(memory_format=torch.channels_last)
+
+    if next(model.parameters()).dtype == torch.float16:
+        batch_tensor = batch_tensor.half()
+
+    inputs = [batch_tensor]
+
+    if warmup:
+        with torch.no_grad():
+            model(inputs)
+
+    amp_ctx = torch.cuda.amp.autocast(enabled=use_amp and device.type == "cuda")
+    start = time.time()
+    with torch.no_grad(), amp_ctx:
+        outputs = model(inputs)
+    latency = (time.time() - start) * 1000
+
+    out = outputs[0]
+    boxes = out["boxes"].detach().cpu().numpy()
+    scores = out["scores"].detach().cpu().numpy()
+    labels = out["labels"].detach().cpu().numpy()
+
+    dets = []
+    for box, score, label_idx in zip(boxes, scores, labels):
+        if score < score_thresh:
+            continue
+        dets.append(
+            {
+                "label": str(label_idx),
+                "score": float(score),
+                "box": [float(x) for x in box],
+            }
+        )
+
+    return {
+        "detections": dets,
+        "latency": latency,
+        "image": image_rgb,
+    }
+
+
+def attach_detection_labels(detections: list[dict], label_names: list[str]) -> list[dict]:
+    labeled = []
+    for det in detections:
+        idx = int(det["label"])
+        name = label_names[idx] if idx < len(label_names) else f"Class {idx}"
+        labeled.append({**det, "label": name})
+    return labeled
+
+
+def get_detection_state_module(model, backend: str):
+    if backend == "ultralytics" and hasattr(model, "model"):
+        return model.model
+    return model
+
+
+def build_detection_metrics(
+    original_result: dict,
+    optimized_result: dict,
+    size_original: float,
+    size_optimized: float,
+    optimized_label: str,
+    score_thresh: float,
+):
+    orig_dets = original_result["detections"]
+    opt_dets = optimized_result["detections"]
+    mean_score_orig = float(np.mean([d["score"] for d in orig_dets])) if orig_dets else 0.0
+    mean_score_opt = float(np.mean([d["score"] for d in opt_dets])) if opt_dets else 0.0
+
+    metrics_df = pd.DataFrame(
+        {
+            "Metric": [
+                "Latency (ms)",
+                f"Detections (score>={score_thresh})",
+                "Mean Score",
+                "Model Size (MB)",
+            ],
+            "Original Model": [
+                f"{original_result['latency']:.2f}",
+                str(len(orig_dets)),
+                f"{mean_score_orig:.3f}",
+                f"{size_original:.2f}",
+            ],
+            optimized_label: [
+                f"{optimized_result['latency']:.2f}",
+                str(len(opt_dets)),
+                f"{mean_score_opt:.3f}",
+                f"{size_optimized:.2f}",
+            ],
+        }
+    )
+    return metrics_df
+
+
+def build_detection_comparison_df(
+    orig_dets: list[dict],
+    opt_dets: list[dict],
+    optimized_label: str,
+    max_rows: int = 50,
+) -> pd.DataFrame:
+    rows = []
+    for det in orig_dets:
+        rows.append(
+            {
+                "Model": "Original",
+                "Class": det["label"],
+                "Score": round(det["score"], 3),
+                "Box [x1,y1,x2,y2]": [round(x, 1) for x in det["box"]],
+            }
+        )
+    for det in opt_dets:
+        rows.append(
+            {
+                "Model": optimized_label,
+                "Class": det["label"],
+                "Score": round(det["score"], 3),
+                "Box [x1,y1,x2,y2]": [round(x, 1) for x in det["box"]],
+            }
+        )
+    if max_rows and len(rows) > max_rows:
+        rows = rows[:max_rows]
+    return pd.DataFrame(rows)
+
+
 def run_segmentation_inference(
     model: nn.Module,
     image,
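
Taken together, a hedged sketch of how these helpers compose (assumes the surrounding app.py context; `sample.jpg` and `boxes.jpg` are placeholder paths):

```python
from PIL import Image

name = "Faster R-CNN ResNet50 FPN (COCO)"
model = get_detection_model(name)  # cached; downloads COCO weights on first use
result = run_detection_inference(
    model,
    Image.open("sample.jpg"),       # placeholder input image
    torch.device("cpu"),
    get_detection_transform(name),
    channels_last=False,
    warmup=False,
    use_amp=False,
    score_thresh=0.5,
)
dets = attach_detection_labels(result["detections"], get_detection_labels(name))
draw_detections(result["image"], dets).save("boxes.jpg")
print(f"{len(dets)} boxes in {result['latency']:.1f} ms")
```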
 
@@ -880,6 +1229,310 @@ def run_quantized(
     return metrics_df, chart_fig, downloads
 
 
+def run_pruned_detection(
+    img,
+    model_choice,
+    method,
+    amount,
+    device_choice="auto",
+    channels_last=False,
+    use_compile=False,
+    use_amp=False,
+    export_ts=False,
+    export_onnx=False,
+    export_report=False,
+    export_state=True,
+    preset=None,
+    score_thresh=0.25,
+):
+    print("\n=== RUN DETECTION PRUNED CALLED ===")
+    if img is None:
+        print("ERROR: Image is None")
+        empty_metrics = pd.DataFrame({"Metric": ["Error"], "Original Model": ["No image"], "Pruned Model": [""]})
+        return empty_metrics, None, pd.DataFrame(), []
+
+    if preset in PRESETS:
+        preset_cfg = PRESETS[preset]
+        device_choice = preset_cfg["device"]
+        channels_last = preset_cfg["channels_last"]
+        use_compile = preset_cfg["compile"]
+        use_amp = preset_cfg.get("amp", use_amp)
+        amount = preset_cfg.get("prune_amount", amount)
+
+    device = select_device(device_choice)
+    cfg = get_detection_config(model_choice)
+    backend = cfg.get("backend", "torchvision")
+    imgsz = cfg.get("imgsz")
+    labels = get_detection_labels(model_choice)
+    transform_fn = get_detection_transform(model_choice)
+
+    base_model = get_detection_model(model_choice)
+    original_result = run_detection_inference(
+        base_model,
+        img,
+        device,
+        transform_fn,
+        channels_last=channels_last,
+        warmup=True,
+        use_amp=use_amp,
+        score_thresh=score_thresh,
+        backend=backend,
+        imgsz=imgsz,
+    )
+    original_result["detections"] = attach_detection_labels(original_result["detections"], labels)
+
+    fresh_model = clone_detection_model(model_choice)
+    pruned_module = apply_pruning(get_detection_state_module(fresh_model, backend), amount=float(amount), method=method)
+    pruned_module = maybe_compile(pruned_module, use_compile)
+    if backend == "ultralytics" and hasattr(fresh_model, "model"):
+        fresh_model.model = pruned_module
+        pruned_model = fresh_model
+    else:
+        pruned_model = pruned_module
+    pruned_result = run_detection_inference(
+        pruned_model,
+        img,
+        device,
+        transform_fn,
+        channels_last=channels_last,
+        warmup=True,
+        use_amp=use_amp,
+        score_thresh=score_thresh,
+        backend=backend,
+        imgsz=imgsz,
+    )
+    pruned_result["detections"] = attach_detection_labels(pruned_result["detections"], labels)
+
+    size_orig = get_state_dict_size_mb(get_detection_state_module(base_model, backend))
+    size_pruned = get_state_dict_size_mb(get_detection_state_module(pruned_model, backend))
+
+    metrics_df = build_detection_metrics(
+        original_result, pruned_result, size_orig, size_pruned, "Pruned Model", score_thresh
+    )
+    det_df = build_detection_comparison_df(original_result["detections"], pruned_result["detections"], "Pruned")
+    overlay_slider_value = (
+        draw_detections(original_result["image"], original_result["detections"]),
+        draw_detections(pruned_result["image"], pruned_result["detections"]),
+    )
+
+    downloads: list[str] = []
+    export_dir = Path("exports")
+    export_dir.mkdir(exist_ok=True)
+    trace_inputs = None
+
+    if backend != "ultralytics":
+        sample_tensor, _ = prepare_detection_input(img, transform_fn)
+        sample_batch = [sample_tensor]
+        trace_inputs = (sample_batch,)
+    elif export_ts or export_onnx:
+        print("TorchScript/ONNX export is not enabled for YOLO12 models in this app.")
+        export_ts = False
+        export_onnx = False
+
+    if export_report:
+        report_path = export_dir / "pruned_det_report.json"
+        report = {
+            "model": model_choice,
+            "pruning": {"method": method, "amount": float(amount)},
+            "score_threshold": score_thresh,
+            "metrics": metrics_df.to_dict(),
+            "detections": {
+                "original": original_result["detections"],
+                "pruned": pruned_result["detections"],
+            },
+        }
+        report_path.write_text(json.dumps(report, indent=2))
+        downloads.append(str(report_path))
+
+    if export_state:
+        state_path = export_dir / "pruned_det_state_dict.pth"
+        torch.save(get_detection_state_module(pruned_model, backend).state_dict(), state_path)
+        downloads.append(str(state_path))
+
+    if export_ts and trace_inputs is not None:
+        ts_path = export_dir / "pruned_det_model.ts"
+        try:
+            scripted = torch.jit.trace(pruned_model.cpu(), trace_inputs)
+            scripted.save(ts_path)
+            downloads.append(str(ts_path))
+        except Exception as exc:  # pragma: no cover - export best effort
+            print(f"TorchScript export failed: {exc}")
+
+    if export_onnx and trace_inputs is not None:
+        onnx_path = export_dir / "pruned_det_model.onnx"
+        try:
+            torch.onnx.export(
+                pruned_model.cpu(),
+                trace_inputs,
+                onnx_path,
+                input_names=["images"],
+                output_names=["detections"],
+                opset_version=13,
+                dynamic_axes={"images": {0: "batch", 2: "height", 3: "width"}},
+            )
+            downloads.append(str(onnx_path))
+        except Exception as exc:  # pragma: no cover - export best effort
+            print(f"ONNX export failed: {exc}")
+
+    print("=== RUN DETECTION PRUNED COMPLETE ===")
+    return metrics_df, overlay_slider_value, det_df, downloads
+
+
+def run_quantized_detection(
+    img,
+    model_choice,
+    q_type,
+    device_choice="auto",
+    channels_last=False,
+    use_compile=False,
+    use_amp=False,
+    export_ts=False,
+    export_onnx=False,
+    export_report=False,
+    export_state=True,
+    preset=None,
+    score_thresh=0.25,
+):
+    print("\n=== RUN DETECTION QUANTIZED CALLED ===")
+    if img is None:
+        print("ERROR: Image is None")
+        empty_metrics = pd.DataFrame({"Metric": ["Error"], "Original Model": ["No image"], "Quantized Model": [""]})
+        return empty_metrics, None, pd.DataFrame(), []
+
+    if preset in PRESETS:
+        preset_cfg = PRESETS[preset]
+        device_choice = preset_cfg["device"]
+        channels_last = preset_cfg["channels_last"]
+        use_compile = preset_cfg["compile"]
+        use_amp = preset_cfg.get("amp", use_amp)
+        q_type = preset_cfg.get("quant", q_type)
+
+    device = select_device(device_choice)
+    if q_type in {"dynamic", "weight_only"} and device.type != "cpu":
+        print("Dynamic/weight-only quantization uses CPU kernels; switching device to CPU.")
+        device = torch.device("cpu")
+        channels_last = False
+        use_amp = False
+    cfg = get_detection_config(model_choice)
+    backend = cfg.get("backend", "torchvision")
+    imgsz = cfg.get("imgsz")
+
+    labels = get_detection_labels(model_choice)
+    transform_fn = get_detection_transform(model_choice)
+    base_model = get_detection_model(model_choice)
+
+    original_result = run_detection_inference(
+        base_model,
+        img,
+        device,
+        transform_fn,
+        channels_last=channels_last,
+        warmup=True,
+        use_amp=use_amp,
+        score_thresh=score_thresh,
+        backend=backend,
+        imgsz=imgsz,
+    )
+    original_result["detections"] = attach_detection_labels(original_result["detections"], labels)
+
+    fresh_model = clone_detection_model(model_choice)
+    quant_module = apply_quantization(get_detection_state_module(fresh_model, backend), q_type)
+    quant_module = maybe_compile(quant_module, use_compile)
+    if backend == "ultralytics" and hasattr(fresh_model, "model"):
+        fresh_model.model = quant_module
+        quant_model = fresh_model
+    else:
+        quant_model = quant_module
+    quant_result = run_detection_inference(
+        quant_model,
+        img,
+        device,
+        transform_fn,
+        channels_last=channels_last,
+        warmup=True,
+        use_amp=use_amp,
+        score_thresh=score_thresh,
+        backend=backend,
+        imgsz=imgsz,
+    )
+    quant_result["detections"] = attach_detection_labels(quant_result["detections"], labels)
+
+    size_orig = get_state_dict_size_mb(get_detection_state_module(base_model, backend))
+    size_quant = get_state_dict_size_mb(get_detection_state_module(quant_model, backend))
+    metrics_df = build_detection_metrics(
+        original_result, quant_result, size_orig, size_quant, "Quantized Model", score_thresh
+    )
+    det_df = build_detection_comparison_df(original_result["detections"], quant_result["detections"], "Quantized")
+    overlay_slider_value = (
+        draw_detections(original_result["image"], original_result["detections"]),
+        draw_detections(quant_result["image"], quant_result["detections"]),
+    )
+
+    downloads: list[str] = []
+    export_dir = Path("exports")
+    export_dir.mkdir(exist_ok=True)
+    trace_inputs = None
+
+    if backend != "ultralytics":
+        sample_tensor, _ = prepare_detection_input(img, transform_fn)
+        sample_batch = [sample_tensor]
+        trace_inputs = (sample_batch,)
+    elif export_ts or export_onnx:
+        print("TorchScript/ONNX export is not enabled for YOLO12 models in this app.")
+        export_ts = False
+        export_onnx = False
+
+    if export_report:
+        report_path = export_dir / "quant_det_report.json"
+        report = {
+            "model": model_choice,
+            "quantization": q_type,
+            "score_threshold": score_thresh,
+            "metrics": metrics_df.to_dict(),
+            "detections": {
+                "original": original_result["detections"],
+                "quantized": quant_result["detections"],
+            },
+        }
+        report_path.write_text(json.dumps(report, indent=2))
+        downloads.append(str(report_path))
+
+    if export_state:
+        state_path = export_dir / "quant_det_state_dict.pth"
+        torch.save(get_detection_state_module(quant_model, backend).state_dict(), state_path)
+        downloads.append(str(state_path))
+
+    if export_ts and trace_inputs is not None:
+        ts_path = export_dir / "quant_det_model.ts"
+        try:
+            scripted = torch.jit.trace(quant_model.cpu(), trace_inputs)
+            scripted.save(ts_path)
+            downloads.append(str(ts_path))
+        except Exception as exc:  # pragma: no cover - export best effort
+            print(f"TorchScript export failed: {exc}")
+
+    if export_onnx and trace_inputs is not None:
+        onnx_path = export_dir / "quant_det_model.onnx"
+        try:
+            torch.onnx.export(
+                quant_model.cpu(),
+                trace_inputs,
+                onnx_path,
+                input_names=["images"],
+                output_names=["detections"],
+                opset_version=13,
+                dynamic_axes={"images": {0: "batch", 2: "height", 3: "width"}},
+            )
+            downloads.append(str(onnx_path))
+        except Exception as exc:  # pragma: no cover - export best effort
+            print(f"ONNX export failed: {exc}")
+
+    print("=== RUN DETECTION QUANTIZED COMPLETE ===")
+    return metrics_df, overlay_slider_value, det_df, downloads
+
+
 def run_pruned_segmentation(
     img,
     model_choice,
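
`apply_pruning` itself is defined outside this diff; as a rough sketch, a helper like it typically wraps the standard `torch.nn.utils.prune` calls the README names (this is an assumption about its shape, not the commit's exact code):

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

def apply_pruning_sketch(model: nn.Module, amount: float, method: str) -> nn.Module:
    """Hypothetical stand-in for app.py's apply_pruning."""
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            if method == "structured":
                # LN-structured: drop whole output channels by L2 norm (dim=0).
                prune.ln_structured(module, name="weight", amount=amount, n=2, dim=0)
            else:
                # L1-unstructured: zero individual low-magnitude weights.
                prune.l1_unstructured(module, name="weight", amount=amount)
            prune.remove(module, "weight")  # bake the mask into the weight tensor
    return model
```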
 
@@ -1199,6 +1852,7 @@ def create_demo():
         device_opts.append("mps")
     preset_opts = list(PRESETS.keys()) + ["custom"]
     seg_model_options = [cfg.name for cfg in SEGMENTATION_MODEL_CONFIGS]
+    det_model_options = DETECTION_MODEL_OPTIONS.copy()
 
     with gr.Tabs():
         # ---- PRUNING TAB ----
@@ -1345,6 +1999,136 @@
                 outputs=[metrics_q, chart_q, downloads_q],
             )
 
+        # ---- DETECTION PRUNING TAB ----
+        with gr.Tab("Pruning-Detection"):
+            with gr.Row():
+                with gr.Column():
+                    img_dp = gr.Image(label="Upload Image")
+                    model_dp = gr.Dropdown(det_model_options, value=det_model_options[0], label="Object Detector (COCO)")
+                    preset_dp = gr.Dropdown(preset_opts, value="custom", label="Hardware Preset")
+                    method_dp = gr.Dropdown(["unstructured", "structured"], value="structured", label="Pruning Method")
+                    amount_dp = gr.Slider(minimum=0.1, maximum=0.9, step=0.1, value=0.3, label="Pruning Amount")
+                    score_dp = gr.Slider(minimum=0.05, maximum=0.9, step=0.05, value=0.25, label="Score Threshold")
+                    device_dp = gr.Dropdown(device_opts, value=device_opts[0], label="Device")
+                    channels_last_dp = gr.Checkbox(label="Channels-last input (CUDA)", value=True)
+                    amp_dp = gr.Checkbox(label="Mixed precision (AMP)", value=True)
+                    compile_dp = gr.Checkbox(label="Torch compile (PyTorch 2)")
+                    export_ts_dp = gr.Checkbox(label="Export TorchScript")
+                    export_onnx_dp = gr.Checkbox(label="Export ONNX")
+                    export_report_dp = gr.Checkbox(label="Export JSON report", value=True)
+                    btn_dp = gr.Button("Run Detection Pruning")
+                    gr.Examples(examples=examples, inputs=img_dp)
+                    gr.Markdown(
+                        "### 🦾 Detection Pruning Guide\n\n"
+                        "**Models:**\n"
+                        "- TorchVision: Faster R-CNN ResNet50 FPN, SSDlite320 MobileNetV3 (COCO pretrained)\n"
+                        "- Ultralytics YOLO12: sizes n/s/m/l/x (COCO, auto-downloaded if missing)\n\n"
+                        "**Core Options:**\n"
+                        "- *Hardware Preset*: Same CPU/GPU defaults as classification; channels-last only applies on CUDA.\n"
+                        "- *Pruning Method*: Structured is safest for detection heads; unstructured yields higher sparsity but rarely speeds up NMS.\n"
+                        "- *Score Threshold*: Filters low-confidence boxes before metrics/overlays.\n"
+                        "- *AMP / Torch Compile*: Only useful on GPU; compile adds startup cost but can speed up steady-state runs.\n"
+                        "- *YOLO12 exports*: TorchScript/ONNX disabled here; state_dict still saved for the underlying torch model.\n\n"
+                        "**Reading Results:**\n"
+                        "- Metrics: latency, box count above threshold, mean score, model size.\n"
+                        "- Overlay slider: drag to compare original vs pruned detections.\n"
+                        "- Detections table: flattened list of boxes for quick scanning."
+                    )
+
+                with gr.Column():
+                    metrics_dp = gr.Dataframe(label="📊 Detection Metrics", headers=["Metric", "Original Model", "Pruned Model"])
+                    overlay_dp = gr.ImageSlider(label="Overlay Comparison", type="pil")
+                    dets_dp = gr.Dataframe(label="Detections (Original vs Pruned)")
+                    downloads_dp = gr.Files(label="Exports (state_dict / TorchScript / ONNX / report)")
+
+            btn_dp.click(
+                fn=run_pruned_detection,
+                inputs=[
+                    img_dp,
+                    model_dp,
+                    method_dp,
+                    amount_dp,
+                    device_dp,
+                    channels_last_dp,
+                    compile_dp,
+                    amp_dp,
+                    export_ts_dp,
+                    export_onnx_dp,
+                    export_report_dp,
+                    gr.State(True),
+                    preset_dp,
+                    score_dp,
+                ],
+                outputs=[
+                    metrics_dp,
+                    overlay_dp,
+                    dets_dp,
+                    downloads_dp,
+                ],
+            )
+
+        # ---- DETECTION QUANTIZATION TAB ----
+        with gr.Tab("Quantization-Detection"):
+            with gr.Row():
+                with gr.Column():
+                    img_dq = gr.Image(label="Upload Image")
+                    model_dq = gr.Dropdown(det_model_options, value=det_model_options[0], label="Object Detector (COCO)")
+                    preset_dq = gr.Dropdown(preset_opts, value="custom", label="Hardware Preset")
+                    q_type_dq = gr.Dropdown(["dynamic", "weight_only", "fp16"], value="dynamic", label="Quantization Type")
+                    score_dq = gr.Slider(minimum=0.05, maximum=0.9, step=0.05, value=0.25, label="Score Threshold")
+                    device_dq = gr.Dropdown(device_opts, value=device_opts[0], label="Device")
+                    channels_last_dq = gr.Checkbox(label="Channels-last input (CUDA)", value=True)
+                    amp_dq = gr.Checkbox(label="Mixed precision (AMP)", value=True)
+                    compile_dq = gr.Checkbox(label="Torch compile (PyTorch 2)")
+                    export_ts_dq = gr.Checkbox(label="Export TorchScript")
+                    export_onnx_dq = gr.Checkbox(label="Export ONNX")
+                    export_report_dq = gr.Checkbox(label="Export JSON report", value=True)
+                    btn_dq = gr.Button("Run Detection Quantization")
+                    gr.Examples(examples=examples, inputs=img_dq)
+                    gr.Markdown(
+                        "### ⚡ Detection Quantization Guide\n\n"
+                        "**Models:** TorchVision detectors and YOLO12 n/s/m/l/x (Ultralytics). YOLO12 uses its internal preprocessing; other models use TorchVision transforms.\n\n"
+                        "**Quantization Modes:**\n"
+                        "- *Dynamic / Weight-only*: INT8 linear layers on CPU. The UI auto-switches to CPU even if a GPU is selected (PyTorch limitation).\n"
+                        "- *FP16*: Half precision for CUDA/MPS; keeps CPU in FP32. Pair with AMP + channels-last for best GPU speed.\n\n"
+                        "**Tips:**\n"
+                        "- Score threshold trims noisy boxes before metrics/overlays.\n"
+                        "- TorchScript/ONNX exports are skipped for YOLO12; state_dict still saved. TorchVision exports remain enabled.\n"
+                        "- For fastest runs, keep AMP + channels-last on CUDA; disable compile if you only run a single image.\n\n"
+                        "**Outputs:** Metrics table, overlay slider, detections table, and exports in `exports/` with a `_det` suffix."
+                    )
+
+                with gr.Column():
+                    metrics_dq = gr.Dataframe(label="📊 Detection Metrics", headers=["Metric", "Original Model", "Quantized Model"])
+                    overlay_dq = gr.ImageSlider(label="Overlay Comparison", type="pil")
+                    dets_dq = gr.Dataframe(label="Detections (Original vs Quantized)")
+                    downloads_dq = gr.Files(label="Exports (state_dict / TorchScript / ONNX / report)")
+
+            btn_dq.click(
+                fn=run_quantized_detection,
+                inputs=[
+                    img_dq,
+                    model_dq,
+                    q_type_dq,
+                    device_dq,
+                    channels_last_dq,
+                    compile_dq,
+                    amp_dq,
+                    export_ts_dq,
+                    export_onnx_dq,
+                    export_report_dq,
+                    gr.State(True),
+                    preset_dq,
+                    score_dq,
+                ],
+                outputs=[
+                    metrics_dq,
+                    overlay_dq,
+                    dets_dq,
+                    downloads_dq,
+                ],
+            )
+
         # ---- SEGMENTATION PRUNING TAB ----
         with gr.Tab("Pruning-Segmentation"):
             with gr.Row():
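
Since the README advertises a lightweight CLI mode, the new detection entry points can also be driven headlessly; a hedged example, assuming a local `sample.jpg` (placeholder path):

```python
from PIL import Image

metrics, overlays, det_table, files = run_pruned_detection(
    Image.open("sample.jpg"),            # placeholder image
    "SSDlite320 MobileNetV3 (COCO)",
    method="structured",
    amount=0.3,
    device_choice="cpu",
    export_report=True,
)
print(metrics.to_string(index=False))
print(files)  # e.g. ['exports/pruned_det_report.json', 'exports/pruned_det_state_dict.pth']
```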
requirements.txt CHANGED
@@ -5,6 +5,7 @@ timm>=0.9.12
 segmentation-models-pytorch>=0.3.3
 huggingface-hub>=0.23.0
 albumentations>=1.4.8
+ultralytics>=8.3.0  # YOLO12 detection backends
 
 # UI
 gradio>=4.19.2