Commit 46b3a11 · Parent(s): 7972153
Refactor app.py for model optimization: add pruning and quantization options.

Files changed:
- .github/workflows/huggingface.yml (+25, -0)
- .gitignore (+4, -0)
- README.md (+106, -1)
- app.py (+819, -180)
.github/workflows/huggingface.yml
ADDED
@@ -0,0 +1,25 @@
```yaml
name: Sync to Hugging Face hub
on:
  push:
    branches: [main]

  # to run this workflow manually from the Actions tab
  workflow_dispatch:

jobs:
  sync-to-hub:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v5
        with:
          fetch-depth: 0
      - name: Add remote
        env:
          HF: ${{ secrets.HF_TOKEN }}
          HFUSER: ${{ secrets.HFUSER }}
        run: git remote add space https://$HFUSER:$HF@huggingface.co/spaces/$HFUSER/model-optimization-lab
      - name: Push to huggingface hub
        env:
          HF: ${{ secrets.HF_TOKEN }}
          HFUSER: ${{ secrets.HFUSER }}
        run: git push --force https://$HFUSER:$HF@huggingface.co/spaces/$HFUSER/model-optimization-lab main
```
.gitignore
ADDED
@@ -0,0 +1,4 @@
```
*.pth
exports/
__pycache__/
*.pyc
```
README.md
CHANGED
@@ -1 +1,106 @@
---
title: Model Optimization Lab
emoji: 💻
colorFrom: pink
colorTo: gray
sdk: gradio
sdk_version: 6.0.0
app_file: app.py
pinned: false
---

# Model Optimization Lab

Interactive Gradio playground for comparing pruning and quantization on ImageNet-classification backbones. Upload any image and observe how latency, confidence, and model size change when applying different compression recipes. Pretrained weights are loaded by default; set `MODEL_OPT_PRETRAINED=0` if you want random initialization for experimentation.

## Features
- Baseline FP32 inference using cached backbones (ResNet-50, MobileNetV3, EfficientNet-B0, etc.).
- Pruning tab: structured/unstructured pruning with configurable sparsity and size/latency comparison.
- Quantization tab: dynamic, weight-only INT8, and FP16 passes with CPU-safe fallbacks for unsupported kernels.
- Automated metric tables and Top-5 bar charts to visualize confidence shifts between optimized variants.

## Requirements
- Python 3.9+
- PyTorch with CPU support (GPU optional but recommended for FP16 experiments).
- The packages listed in `requirements.txt`, installed via `pip install -r requirements.txt`. If the file is missing, create it with entries such as `torch`, `torchvision`, `timm`, `gradio`, and `pandas`; a minimal sketch follows this list.
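A minimal `requirements.txt` sketch, assuming only the packages `app.py` actually imports (pin versions as your environment requires):

```text
torch
torchvision
timm
gradio
pandas
numpy
matplotlib
Pillow
```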
## Quick Start
1. Clone the repository:
   ```bash
   git clone https://github.com/shriarul5273/model-optimization-lab.git
   cd model-optimization-lab
   ```
2. Create and activate a virtual environment (optional but recommended).
3. Install dependencies:
   ```bash
   pip install -r requirements.txt
   ```
4. Launch the Gradio app:
   ```bash
   python app.py
   ```
5. Open the local Gradio URL (printed in the terminal) in your browser.
## Using the App
1. **Upload an image** or pick one of the provided examples.
2. Choose the **Base Model** dropdown (ResNet-50, MobileNetV3-Large, EfficientNet-B0, ConvNeXt-Tiny, ViT-B/16, RegNetY-016, EfficientNet-Lite0).
3. Pick a **Hardware Preset** or keep `custom`:
   - Edge CPU → CPU, channels-last off, dynamic quantization, 30% pruning.
   - Datacenter GPU → CUDA, channels-last on, `torch.compile`, FP16 quantization, 20% pruning.
   - Apple MPS → MPS, FP16 quantization, 20% pruning.
4. Pick a tab and set options, then click **Run**.
### Pruning tab options
- `Pruning Method`: `structured` (LN-structured) or `unstructured` (L1). Applied to Conv2d weights before export.
- `Pruning Amount`: 0.1–0.9 sparsity. Higher values zero more weights; the latency impact depends on kernel support.
- `Device`: `auto` picks CUDA → MPS → CPU. Channels-last is honored only on CUDA.
- `Channels-last input (CUDA)`: Converts tensors to channels-last for better CUDA kernel throughput.
- `Mixed precision (AMP)`: Enables CUDA autocast for FP16/FP32 mixes.
- `Torch compile (PyTorch 2)`: Wraps the pruned model in `torch.compile` when available.
- Exports: TorchScript (`pruned_model.ts`), ONNX (`pruned_model.onnx`), JSON report; `pruned_state_dict.pth` is always saved.
- Outputs: comparison metrics, Top-5 bar chart, per-layer sparsity table, download list of artifacts. A pruning sketch follows this list.
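A minimal sketch of what the two methods do, mirroring `apply_pruning` in `app.py`:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

def apply_pruning(model, amount=0.5, method="unstructured"):
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            if method == "unstructured":
                # Zero the smallest-magnitude weights individually (L1 criterion)
                prune.l1_unstructured(module, name="weight", amount=amount)
            elif method == "structured":
                # Zero whole output channels ranked by L2 norm
                prune.ln_structured(module, name="weight", amount=amount, n=2, dim=0)
    # Fold masks into the weights so the reported sparsity and size are real
    for module in model.modules():
        if isinstance(module, nn.Conv2d) and hasattr(module, "weight_orig"):
            prune.remove(module, "weight")
    return model
```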
### Quantization tab options
- `Quantization Type`: `dynamic`/`weight_only` (INT8 linear layers on CPU) or `fp16` (casts the model to half precision).
- `Device`: `auto` picks CUDA → MPS → CPU; dynamic/weight-only runs force CPU execution for kernel support.
- `Channels-last input (CUDA)`: Same as pruning; ignored on CPU.
- `Mixed precision (AMP)`: Applies CUDA autocast to the quantized forward pass.
- `Torch compile (PyTorch 2)`: Compiles the quantized model when available.
- Exports: TorchScript (`quantized_model.ts`), ONNX (`quantized_model.onnx`), JSON report; `quantized_state_dict.pth` is always saved.
- Outputs: comparison metrics, Top-5 bar chart, download list of artifacts. See the sketch after this list.
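The quantization pass itself is small (mirroring `apply_quantization` in `app.py`); dynamic quantization is weight-only by default, so both options share one code path:

```python
import torch
import torch.nn as nn

def apply_quantization(model, q_type="dynamic"):
    if q_type in {"dynamic", "weight_only"}:
        # INT8 weights for Linear layers; activations are quantized on the fly (CPU kernels)
        return torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
    if q_type == "fp16":
        return model.half().eval()
    return model
```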
### What gets exported
- Artifacts are written to `exports/`. JSON reports include the chosen options, metrics, and Top-5 results for both the baseline and optimized variants.
- TorchScript/ONNX exports run best on CPU inputs; failures are logged to the console and skipped (see the sketch after this list).
- State dicts are always saved for reproducibility; disable or prune them manually if you are embedding this module elsewhere.
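The export path is deliberately defensive; a sketch of the trace-on-CPU, skip-on-failure pattern `app.py` uses (`try_torchscript_export` is an illustrative name, not part of `app.py`):

```python
import torch

def try_torchscript_export(model, sample_cpu, path="exports/model.ts"):
    """Trace on CPU inputs; log and skip if the model cannot be traced."""
    try:
        scripted = torch.jit.trace(model.cpu(), sample_cpu)
        scripted.save(path)
        return path
    except Exception as exc:
        print(f"TorchScript export failed: {exc}")
        return None
```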
### Output Interpretation Tips
- **Top-1 Prediction**: Labels come from ImageNet synsets, so some entries include multiple comma-separated synonyms (e.g., `chambered nautilus, pearly nautilus`).
- **Latency (ms)**: The measured inference latency for each pass. Large numbers for quantized runs may indicate preprocessing overhead rather than faster model execution; see [Performance Notes](#performance-notes).
- **Model Size (MB)**: Serialized state-dictionary size (see the sketch after this list).
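Model size is measured by serializing the state dict in memory, so no temporary files are needed; a sketch mirroring `get_state_dict_size_mb` in `app.py`:

```python
import io
import torch
import torch.nn as nn

def get_state_dict_size_mb(model: nn.Module) -> float:
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)  # serialize without touching disk
    return len(buffer.getbuffer()) / 1e6
```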
## Performance Notes
- The current quantization pipeline rebuilds the optimized model on each request, so the reported latency includes setup time. Reusing pre-quantized instances will yield more realistic numbers.
- Dynamic and weight-only quantization affect only linear layers; ResNet-50 is dominated by convolution blocks that remain FP32, so speedups are modest on CPU. Unsupported static INT8 kernels automatically fall back to dynamic quantization.
- PyTorch's default quantization backend may fall back to `qnnpack` on CPU. For x86 systems, set `torch.backends.quantized.engine = "fbgemm"` before quantizing for best results (see the sketch after this list).
- FP16 inference is beneficial on GPUs. On CPU, PyTorch often casts half tensors back to float32, introducing overhead.
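A minimal sketch of pinning the backend before quantizing, assuming an x86 build of PyTorch with fbgemm available:

```python
import timm
import torch
import torch.nn as nn

# Pin the x86 quantization backend; use "qnnpack" on ARM instead.
torch.backends.quantized.engine = "fbgemm"

model = timm.create_model("resnet50", pretrained=False).eval()
model_int8 = torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
```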
## Extending the Lab
- Swap in different architectures by changing the `timm.create_model` call in `app.py`.
- Add calibration data and static INT8 quantization to include convolution layers.
- Cache optimized models to avoid recomputation across requests (a sketch follows this list).
- Integrate evaluation datasets to quantify accuracy drop beyond top-1 confidence.
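One way to implement the caching idea, keyed on the full optimization config; `_OPTIMIZED_CACHE` and `get_optimized` are illustrative names, not part of `app.py`:

```python
import torch.nn as nn

from app import apply_pruning, apply_quantization, clone_model

_OPTIMIZED_CACHE: dict[tuple, nn.Module] = {}

def get_optimized(model_name: str, mode: str, **cfg) -> nn.Module:
    """Build an optimized variant once, then reuse it across requests."""
    key = (model_name, mode, tuple(sorted(cfg.items())))
    if key not in _OPTIMIZED_CACHE:
        fresh = clone_model(model_name)
        if mode == "prune":
            fresh = apply_pruning(fresh, **cfg)
        else:
            fresh = apply_quantization(fresh, **cfg)
        _OPTIMIZED_CACHE[key] = fresh
    return _OPTIMIZED_CACHE[key]

# Usage: get_optimized("resnet50", "prune", amount=0.4, method="structured")
```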
## CLI Mode
- Run without the UI: `python app.py --cli --image path/to/img.jpg --mode prune --model resnet50 --device auto`
- Modes: `--mode prune` (structured pruning at 0.4 sparsity) or `--mode quant` (dynamic quantization). Both print the metrics table and the list of exported artifacts.
- Devices: `auto` chooses CUDA → MPS → CPU based on availability; `cpu`/`cuda`/`mps` force a device. Dynamic/weight-only quantization forces CPU for kernel support even if a GPU is requested.
- Models: any entry from `MODEL_OPTIONS` in `app.py`. A scripted-sweep sketch follows this list.
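A small sketch of scripting the CLI for repeated runs; the loop and model list are illustrative, the flags are the ones documented above:

```python
import subprocess

# Drive the CLI over several backbones and collect exports in exports/
for model in ["resnet50", "mobilenetv3_large_100", "efficientnet_b0"]:
    subprocess.run(
        ["python", "app.py", "--cli", "--image", "examples/cat.jpg",
         "--mode", "quant", "--model", model, "--device", "cpu"],
        check=True,
    )
```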
## Troubleshooting
- **Slow downloads**: The first run downloads pretrained weights (~100 MB). Subsequent runs use cached files.
- **CUDA errors**: Ensure the correct CUDA-enabled PyTorch build is installed if you intend to run on GPU.
- **Quantized model larger than expected**: The state dictionary includes dequantized tensors for some paths (e.g., dynamic quantization). Consider TorchScript or ONNX export for compact deployment artifacts.
## License
This project inherits the default license of the repository. Replace or update this section if you add a specific license.
app.py
CHANGED

@@ -1,33 +1,127 @@
```python
import argparse
import io
import json
import os
import time
import zipfile
from pathlib import Path
from tempfile import TemporaryDirectory

import matplotlib.pyplot as plt
import gradio as gr
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
from PIL import Image
from torchvision import transforms


# ---------------------------------------------
# Base Model Registry / Defaults
# ---------------------------------------------
MODEL_OPTIONS = [
    "resnet50",
    "mobilenetv3_large_100",
    "efficientnet_b0",
    "convnext_tiny",
    "vit_base_patch16_224",
    "regnety_016",
    "efficientnet_lite0",
]
PRETRAINED_DEFAULT = os.getenv("MODEL_OPT_PRETRAINED", "1") == "1"
_PRETRAINED_DISABLED = False
_PRETRAINED_WARNED = False

PRESETS = {
    "Edge CPU": {
        "device": "cpu",
        "channels_last": False,
        "compile": False,
        "quant": "dynamic",
        "prune_amount": 0.3,
    },
    "Datacenter GPU": {
        "device": "cuda",
        "channels_last": True,
        "compile": True,
        "quant": "fp16",
        "prune_amount": 0.2,
    },
    "Apple MPS": {
        "device": "mps",
        "channels_last": False,
        "compile": False,
        "quant": "fp16",
        "prune_amount": 0.2,
    },
}

_MODEL_CACHE: dict[str, torch.nn.Module] = {}
_TRANSFORM_CACHE: dict[str, transforms.Compose] = {}


def select_device(device_str: str) -> torch.device:
    """Return a valid torch.device based on user selection."""
    device_str = (device_str or "auto").lower()
    if device_str == "cuda" and torch.cuda.is_available():
        return torch.device("cuda")
    if device_str == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    if device_str == "cpu":
        return torch.device("cpu")
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_transform(model_name: str):
    if model_name in _TRANSFORM_CACHE:
        return _TRANSFORM_CACHE[model_name]

    model = get_fp32_model(model_name)

    if hasattr(timm.data, "resolve_model_data_config"):
        data_cfg = timm.data.resolve_model_data_config(model)
    else:
        # Fallback for older timm versions
        data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)

    _TRANSFORM_CACHE[model_name] = timm.data.create_transform(**data_cfg)
    return _TRANSFORM_CACHE[model_name]


def get_fp32_model(model_name: str):
    if model_name not in _MODEL_CACHE:
        global _PRETRAINED_DISABLED, _PRETRAINED_WARNED
        use_pretrained = PRETRAINED_DEFAULT and not _PRETRAINED_DISABLED
        if not use_pretrained and not _PRETRAINED_WARNED:
            print("INFO: MODEL_OPT_PRETRAINED disabled; using randomly initialized weights.")
            _PRETRAINED_WARNED = True
        try:
            loaded = timm.create_model(model_name, pretrained=use_pretrained)
        except Exception as exc:
            print(f"Warning: pretrained weights unavailable ({exc}); using random init for {model_name}")
            _PRETRAINED_DISABLED = True
            loaded = timm.create_model(model_name, pretrained=False)
        loaded.eval()
        _MODEL_CACHE[model_name] = loaded
    return _MODEL_CACHE[model_name]


def clone_model(model_name: str):
    """Create a fresh model loaded from the cached FP32 weights to avoid re-downloads."""
    base = get_fp32_model(model_name)
    fresh = timm.create_model(model_name, pretrained=False)
    fresh.load_state_dict(base.state_dict())
    fresh.eval()
    return fresh


# ---------------------------------------------
# Image Preprocess
# ---------------------------------------------
imagenet_info = timm.data.ImageNetInfo()
labels = [imagenet_info.index_to_description(i) for i in range(1000)]
```

@@ -40,13 +134,15 @@ def apply_pruning(model, amount=0.5, method="unstructured"):
```python
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            if method == "unstructured":
                prune.l1_unstructured(module, name="weight", amount=amount)
            elif method == "structured":
                prune.ln_structured(module, name="weight", amount=amount, n=2, dim=0)

    # Make pruning permanent to reflect true sparsity/size
    for module in model.modules():
        if isinstance(module, nn.Conv2d) and hasattr(module, "weight_orig"):
            prune.remove(module, "weight")
    return model
```

@@ -54,52 +150,201 @@ def apply_pruning(model, amount=0.5, method="unstructured"):
```python
# QUANTIZATION FUNCTION (dynamic)
# ---------------------------------------------
def apply_quantization(model, q_type="dynamic"):
    q_type = q_type or "dynamic"
    if q_type in {"dynamic", "weight_only"}:  # dynamic quantization is weight-only by default
        return torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
    if q_type == "fp16":
        return model.half().eval()
    return model


def compute_sparsity(model: nn.Module) -> pd.DataFrame:
    rows = []
    for name, module in model.named_modules():
        if hasattr(module, "weight"):
            weight = module.weight.detach()
            total = weight.numel()
            zeros = (weight == 0).sum().item()
            sparsity = 100.0 * zeros / max(total, 1)
            rows.append({"Layer": name, "Params": total, "Sparsity %": round(sparsity, 2)})
    return pd.DataFrame(rows)


def _macs_conv2d(inp, module: nn.Conv2d, out):
    # inp: (N, C_in, H, W), out: (N, C_out, H_out, W_out)
    batch, c_in, h, w = inp.shape
    _, c_out, h_out, w_out = out.shape
    kernel_ops = module.kernel_size[0] * module.kernel_size[1] * (c_in / module.groups)
    return batch * h_out * w_out * c_out * kernel_ops


def _macs_linear(inp, module: nn.Linear, out):
    batch = inp.shape[0]
    return batch * module.in_features * module.out_features


def layer_profile(model: nn.Module, sample_input: torch.Tensor) -> pd.DataFrame:
    """Collect per-layer params, MACs, and forward time (single run)."""
    rows = []
    handles = []
    start_times = {}

    def pre_hook(name):
        def _pre(mod, inp):
            start_times[name] = time.time()
        return _pre

    def fwd_hook(name):
        def _fwd(mod, inp, out):
            end = time.time()
            duration_ms = (end - start_times.get(name, end)) * 1000
            inp0 = inp[0] if isinstance(inp, (tuple, list)) else inp
            macs = None
            if isinstance(mod, nn.Conv2d):
                macs = _macs_conv2d(inp0, mod, out)
            elif isinstance(mod, nn.Linear):
                macs = _macs_linear(inp0, mod, out)
            params = sum(p.numel() for p in mod.parameters())
            rows.append({
                "Layer": name,
                "Type": mod.__class__.__name__,
                "Params": params,
                "MACs": macs if macs is None else float(macs),
                "Latency (ms)": round(duration_ms, 3),
            })
        return _fwd

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf
            handles.append(module.register_forward_pre_hook(pre_hook(name)))
            handles.append(module.register_forward_hook(fwd_hook(name)))

    with torch.no_grad():
        model(sample_input)

    for h in handles:
        h.remove()

    return pd.DataFrame(rows)


def maybe_compile(model, use_compile: bool):
    if not use_compile:
        return model
    if not hasattr(torch, "compile"):
        return model
    try:
        return torch.compile(model)
    except Exception:
        return model


def get_state_dict_size_mb(model: nn.Module) -> float:
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return len(buffer.getbuffer()) / 1e6


def prepare_image(image, transform_fn):
    if image is None:
        raise ValueError("No image provided")

    if not isinstance(image, Image.Image):
        if isinstance(image, np.ndarray) and image.dtype != np.uint8:
            image = (np.clip(image, 0, 1) * 255).astype(np.uint8)
        image = Image.fromarray(image.astype("uint8"))

    image = image.convert("RGB")
    tensor = transform_fn(image).unsqueeze(0)
    return tensor


def grad_cam(image, model, device, transform_fn):
    model.eval()
    target_layer = None
    for m in reversed(list(model.modules())):
        if isinstance(m, nn.Conv2d):
            target_layer = m
            break
    if target_layer is None:
        raise ValueError("No Conv2d layer found for Grad-CAM")

    activations = {}
    gradients = {}

    def fwd_hook(module, inp, out):
        activations["value"] = out.detach()

    def bwd_hook(module, grad_in, grad_out):
        gradients["value"] = grad_out[0].detach()

    handle_fwd = target_layer.register_forward_hook(fwd_hook)
    handle_bwd = target_layer.register_full_backward_hook(bwd_hook)

    img_t = prepare_image(image, transform_fn).to(device)
    img_t.requires_grad_(True)

    out = model(img_t)
    top1 = out.argmax(dim=1)
    score = out[0, top1]
    model.zero_grad()
    score.backward()

    act = activations["value"]
    grad = gradients["value"]
    weights = grad.mean(dim=(2, 3), keepdim=True)
    cam = (weights * act).sum(dim=1, keepdim=True)
    cam = F.relu(cam)
    cam = cam.squeeze().cpu().numpy()
    cam -= cam.min()
    cam /= cam.max() + 1e-8

    # Resize CAM to image size and blend a red heatmap over the input
    cam_img = Image.fromarray(np.uint8(cam * 255)).resize(image.size, resample=Image.BILINEAR)
    heatmap = np.array(cam_img)
    heatmap_rgb = np.stack([heatmap, np.zeros_like(heatmap), np.zeros_like(heatmap)], axis=-1)
    overlay = np.array(image.convert("RGB"), dtype=np.float32)
    alpha = 0.35
    blended = (overlay * (1 - alpha) + heatmap_rgb * alpha).clip(0, 255).astype("uint8")
    blended_img = Image.fromarray(blended)

    handle_fwd.remove()
    handle_bwd.remove()
    return blended_img


# ---------------------------------------------
# Inference Function (shared)
# ---------------------------------------------
def run_inference(model, image, device, transform_fn, channels_last=False, warmup=False, use_amp=False):
    print(f" run_inference called, image type: {type(image)}")
    img = prepare_image(image, transform_fn)

    model = model.to(device)
    img = img.to(device)

    if channels_last and device.type == "cuda":
        img = img.to(memory_format=torch.channels_last)

    if next(model.parameters()).dtype == torch.float16:
        img = img.half()

    if warmup:
        with torch.no_grad():
            model(img)

    print(" Running model inference...")
    start = time.time()
    amp_ctx = torch.cuda.amp.autocast(enabled=use_amp and device.type == "cuda")
    with torch.no_grad(), amp_ctx:
        out = model(img)
    latency = (time.time() - start) * 1000

    out_cpu = out.detach().cpu()
    prob = torch.softmax(out_cpu, dim=1)[0]
    top5_prob, top5_idx = torch.topk(prob, 5)

    results = [(labels[i], float(top5_prob[j])) for j, i in enumerate(top5_idx)]
```

@@ -107,47 +352,92 @@ def run_inference(model, image):
```python
    return results, latency


def build_top5_plot(results_orig, results_other, other_label: str):
    classes = []
    for r in results_orig + results_other:
        if r[0] not in classes:
            classes.append(r[0])
    orig_map = {r[0]: r[1] for r in results_orig}
    other_map = {r[0]: r[1] for r in results_other}
    orig_vals = [orig_map.get(c, 0.0) for c in classes]
    other_vals = [other_map.get(c, 0.0) for c in classes]
    x = np.arange(len(classes))
    width = 0.35
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.bar(x - width / 2, orig_vals, width, label="Original")
    ax.bar(x + width / 2, other_vals, width, label=other_label)
    ax.set_xticks(x)
    ax.set_xticklabels(classes, rotation=20, ha="right")
    ax.set_ylabel("Confidence")
    ax.set_xlabel("Class")
    ax.legend()
    fig.tight_layout()
    return fig


# ---------------------------------------------
# Gradio Functions With Options
# ---------------------------------------------
def run_pruned(
    img,
    model_name,
    method,
    amount,
    device_choice="auto",
    channels_last=False,
    use_compile=False,
    use_amp=False,
    export_ts=False,
    export_onnx=False,
    export_report=False,
    export_state=True,
    preset=None,
):
    print("\n=== RUN PRUNED CALLED ===")
    print(f"Image type: {type(img)}, Model: {model_name}, Method: {method}, Amount: {amount}")

    if img is None:
        print("ERROR: Image is None")
        return (
            {"Metric": ["Error"], "Original Model": ["No image uploaded"], "Pruned Model": [""]},
            {"Class": [], "Original": [], "Pruned": []},
            pd.DataFrame(),
            []
        )

    if preset in PRESETS:
        preset_cfg = PRESETS[preset]
        device_choice = preset_cfg["device"]
        channels_last = preset_cfg["channels_last"]
        use_compile = preset_cfg["compile"]
        use_amp = preset_cfg.get("amp", use_amp)
        amount = preset_cfg.get("prune_amount", amount)

    device = select_device(device_choice)
    transform_fn = get_transform(model_name)

    # Run original model
    print("Running original model...")
    fp32_model = get_fp32_model(model_name)
    results_orig, latency_orig = run_inference(fp32_model, img, device, transform_fn, channels_last, warmup=True, use_amp=use_amp)
    print(f"Original model done. Latency: {latency_orig:.2f}ms")

    # Run pruned model
    print("Creating fresh model...")
    fresh_model = clone_model(model_name)
    print("Applying pruning...")
    pruned_model = apply_pruning(fresh_model, amount=float(amount), method=method)
    pruned_model = maybe_compile(pruned_model, use_compile)
    print("Running pruned model...")
    results_pruned, latency_pruned = run_inference(pruned_model, img, device, transform_fn, channels_last, warmup=True, use_amp=use_amp)
    print(f"Pruned model done. Latency: {latency_pruned:.2f}ms")

    # Model sizes (in-memory)
    size_orig = get_state_dict_size_mb(fp32_model)
    size_pruned = get_state_dict_size_mb(pruned_model)
    print(f"Model sizes - Original: {size_orig:.2f}MB, Pruned: {size_pruned:.2f}MB")

    # Comparison metrics - as DataFrame for Gradio
    metrics_df = pd.DataFrame({
        "Metric": ["Top-1 Prediction", "Confidence", "Latency (ms)", "Model Size (MB)"],
        "Original Model": [
```

@@ -163,45 +453,118 @@ def run_pruned(img, method, amount):
```python
            f"{size_pruned:.2f}"
        ]
    })

    chart_fig = build_top5_plot(results_orig, results_pruned, "Pruned")
    sparsity_df = compute_sparsity(pruned_model.cpu())

    downloads = []
    export_dir = Path("exports")
    export_dir.mkdir(exist_ok=True)
    sample_cpu = prepare_image(img, transform_fn)

    if export_report:
        report_path = export_dir / "pruned_report.json"
        report = {
            "model": model_name,
            "pruning": {"method": method, "amount": float(amount)},
            "metrics": metrics_df.to_dict(),
            "top5_pruned": results_pruned,
            "top5_original": results_orig,
        }
        report_path.write_text(json.dumps(report, indent=2))
        downloads.append(str(report_path))

    # Always allow state_dict download for reproducibility
    if export_state:
        state_path = export_dir / "pruned_state_dict.pth"
        torch.save(pruned_model.state_dict(), state_path)
        downloads.append(str(state_path))

    if export_ts:
        ts_path = export_dir / "pruned_model.ts"
        try:
            scripted = torch.jit.trace(pruned_model.cpu(), sample_cpu)
            scripted.save(ts_path)
            downloads.append(str(ts_path))
        except Exception as exc:
            print(f"TorchScript export failed: {exc}")

    if export_onnx:
        onnx_path = export_dir / "pruned_model.onnx"
        try:
            torch.onnx.export(
                pruned_model.cpu(),
                sample_cpu,
                onnx_path,
                input_names=["input"],
                output_names=["output"],
                opset_version=13,
                dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
            )
            downloads.append(str(onnx_path))
        except Exception as exc:
            print(f"ONNX export failed: {exc}")

    print("=== RUN PRUNED COMPLETE ===")
    return metrics_df, chart_fig, sparsity_df, downloads


def run_quantized(
    img,
    model_name,
    q_type,
    device_choice="auto",
    channels_last=False,
    use_compile=False,
    use_amp=False,
    export_ts=False,
    export_onnx=False,
    export_report=False,
    export_state=True,
    preset=None,
):
    print("\n=== RUN QUANTIZED CALLED ===")
    print(f"Image type: {type(img)}, Model: {model_name}, Q-type: {q_type}")

    if img is None:
        print("ERROR: Image is None")
        return (
            {"Metric": ["Error"], "Original Model": ["No image uploaded"], "Quantized Model": [""]},
            {"Class": [], "Original": [], "Quantized": []},
            []
        )

    if preset in PRESETS:
        preset_cfg = PRESETS[preset]
        device_choice = preset_cfg["device"]
        channels_last = preset_cfg["channels_last"]
        use_compile = preset_cfg["compile"]
        use_amp = preset_cfg.get("amp", use_amp)
        q_type = preset_cfg.get("quant", q_type)

    device = select_device(device_choice)
    if q_type in {"dynamic", "weight_only"} and device.type != "cpu":
        print("Dynamic/weight-only quantization uses CPU kernels; switching device to CPU.")
        device = torch.device("cpu")
        channels_last = False
        use_amp = False
    transform_fn = get_transform(model_name)

    # Run original model
    print("Running original model...")
    fp32_model = get_fp32_model(model_name)
    results_orig, latency_orig = run_inference(fp32_model, img, device, transform_fn, channels_last, warmup=True, use_amp=use_amp)
    print(f"Original model done. Latency: {latency_orig:.2f}ms")

    # Run quantized model
    fresh_model = clone_model(model_name)
    quant_model = apply_quantization(fresh_model, q_type)
    quant_model = maybe_compile(quant_model, use_compile)
    results_quant, latency_quant = run_inference(quant_model, img, device, transform_fn, channels_last, warmup=True, use_amp=use_amp)

    size_orig = get_state_dict_size_mb(fp32_model)
    size_quant = get_state_dict_size_mb(quant_model)

    metrics_df = pd.DataFrame({
        "Metric": ["Top-1 Prediction", "Confidence", "Latency (ms)", "Model Size (MB)"],
        "Original Model": [
```

@@ -217,95 +580,371 @@ def run_quantized(img, q_type):
```python
            f"{size_quant:.2f}"
        ]
    })

    chart_fig = build_top5_plot(results_orig, results_quant, "Quantized")

    downloads = []
    export_dir = Path("exports")
    export_dir.mkdir(exist_ok=True)
    sample_cpu = prepare_image(img, transform_fn)

    if export_report:
        report_path = export_dir / "quant_report.json"
        report = {
            "model": model_name,
            "quantization": q_type,
            "metrics": metrics_df.to_dict(),
            "top5_quantized": results_quant,
            "top5_original": results_orig,
        }
        report_path.write_text(json.dumps(report, indent=2))
        downloads.append(str(report_path))

    if export_state:
        state_path = export_dir / "quantized_state_dict.pth"
        torch.save(quant_model.state_dict(), state_path)
        downloads.append(str(state_path))

    if export_ts:
        ts_path = export_dir / "quantized_model.ts"
        try:
            scripted = torch.jit.trace(quant_model.cpu(), sample_cpu)
            scripted.save(ts_path)
            downloads.append(str(ts_path))
        except Exception as exc:
            print(f"TorchScript export failed: {exc}")

    if export_onnx:
        onnx_path = export_dir / "quantized_model.onnx"
        try:
            torch.onnx.export(
                quant_model.cpu(),
                sample_cpu,
                onnx_path,
                input_names=["input"],
                output_names=["output"],
                opset_version=13,
                dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
            )
            downloads.append(str(onnx_path))
        except Exception as exc:
            print(f"ONNX export failed: {exc}")

    print("=== RUN QUANTIZED COMPLETE ===")
    return metrics_df, chart_fig, downloads


def run_profile(image, model_name, variant, device_choice="auto"):
    if image is None:
        return pd.DataFrame()
    device = select_device(device_choice)
    if variant == "quant" and device.type != "cpu":
        device = torch.device("cpu")
    transform_fn = get_transform(model_name)
    sample = prepare_image(image, transform_fn).to(device)

    if variant == "fp32":
        model = get_fp32_model(model_name).to(device)
    elif variant == "pruned":
        model = apply_pruning(clone_model(model_name), amount=0.4)
    else:
        model = apply_quantization(clone_model(model_name), "dynamic")

    profile_df = layer_profile(model, sample)
    return profile_df


def run_batch(images, model_name, mode, device_choice="auto"):
    """Batch runner: returns per-image metrics and aggregate stats."""
    if not images:
        return pd.DataFrame(), pd.DataFrame()

    device = select_device(device_choice)
    transform_fn = get_transform(model_name)

    per_image = []
    latencies = []
    labels_map = {}
    expanded_files = []
    temp_dirs = []

    for path in images:
        if isinstance(path, str) and path.endswith(".zip"):
            td = TemporaryDirectory()
            temp_dirs.append(td)  # keep alive until function ends
            with zipfile.ZipFile(path) as zf:
                zf.extractall(td.name)
            for root, _, files in os.walk(td.name):
                for f in files:
                    if f.lower().endswith((".jpg", ".jpeg", ".png")):
                        expanded_files.append(os.path.join(root, f))
                    if f.lower() in {"labels.txt", "labels.csv"}:
                        with open(os.path.join(root, f)) as lf:
                            for line in lf:
                                parts = line.strip().split(",")
                                if len(parts) >= 2:
                                    labels_map[parts[0]] = parts[1]
        else:
            expanded_files.append(path)

    for path in expanded_files:
        img = Image.open(path) if isinstance(path, str) else path
        if mode == "prune":
            metrics, _, _, _ = run_pruned(
                img,
                model_name,
                "structured",
                0.4,
                device_choice=device_choice,
                export_state=False,
            )
            latency = float(metrics.loc[metrics["Metric"] == "Latency (ms)", "Pruned Model"].values[0])
            top1 = metrics.loc[metrics["Metric"] == "Top-1 Prediction", "Pruned Model"].values[0]
        else:
            metrics, _, _ = run_quantized(
                img,
                model_name,
                "dynamic",
                device_choice=device_choice,
                export_state=False,
            )
            latency = float(metrics.loc[metrics["Metric"] == "Latency (ms)", "Quantized Model"].values[0])
            top1 = metrics.loc[metrics["Metric"] == "Top-1 Prediction", "Quantized Model"].values[0]

        fname = os.path.basename(getattr(path, "name", path))
        record = {"Image": fname, "Top-1": top1, "Latency (ms)": latency}
        if fname in labels_map:
            record["Label"] = labels_map[fname]
            record["Correct"] = labels_map[fname] == top1
        per_image.append(record)
        latencies.append(latency)

    per_image_df = pd.DataFrame(per_image)
    summary = {
        "count": len(latencies),
        "mean_latency": float(np.mean(latencies)),
        "median_latency": float(np.median(latencies)),
        "max_latency": float(np.max(latencies)),
    }
    if "Correct" in per_image_df.columns:
        summary["accuracy"] = float(per_image_df["Correct"].mean())

    summary_df = pd.DataFrame({"Metric": list(summary.keys()), "Value": list(summary.values())})
    return per_image_df, summary_df


def run_sweep(img, model_name, device_choice, experiments_json=None):
    if img is None:
        return pd.DataFrame(), pd.DataFrame()
    default_experiments = [
        {"mode": "prune", "amount": 0.2, "method": "structured"},
        {"mode": "prune", "amount": 0.5, "method": "structured"},
        {"mode": "quant", "q_type": "dynamic"},
        {"mode": "quant", "q_type": "fp16"},
    ]
    try:
        experiments = json.loads(experiments_json) if experiments_json else default_experiments
    except Exception:
        experiments = default_experiments

    rows = []
    for exp in experiments:
        if exp.get("mode") == "prune":
            metrics, _, _, _ = run_pruned(
                img,
                model_name,
                exp.get("method", "structured"),
                exp.get("amount", 0.4),
                device_choice=device_choice,
                export_state=False,
            )
            latency = float(metrics.loc[metrics["Metric"] == "Latency (ms)", "Pruned Model"].values[0])
            size = float(metrics.loc[metrics["Metric"] == "Model Size (MB)", "Pruned Model"].values[0])
            top1 = metrics.loc[metrics["Metric"] == "Top-1 Prediction", "Pruned Model"].values[0]
            rows.append({"mode": "prune", "amount": exp.get("amount"), "latency": latency, "size": size, "top1": top1})
        else:
            metrics, _, _ = run_quantized(
                img,
                model_name,
                exp.get("q_type", "dynamic"),
                device_choice=device_choice,
                export_state=False,
            )
            latency = float(metrics.loc[metrics["Metric"] == "Latency (ms)", "Quantized Model"].values[0])
            size = float(metrics.loc[metrics["Metric"] == "Model Size (MB)", "Quantized Model"].values[0])
            top1 = metrics.loc[metrics["Metric"] == "Top-1 Prediction", "Quantized Model"].values[0]
            rows.append({"mode": "quant", "q_type": exp.get("q_type"), "latency": latency, "size": size, "top1": top1})

    df = pd.DataFrame(rows)
    pareto_df = df.rename(columns={"latency": "Latency (ms)", "size": "Model Size (MB)", "top1": "Top-1"})
    return df, pareto_df


def fastapi_snippet():
    return """
from fastapi import FastAPI, UploadFile
from PIL import Image
import io, torch, timm
from app import get_transform, get_fp32_model, run_inference, select_device

app = FastAPI()
MODEL = "resnet50"
DEVICE = select_device("auto")
MODEL_OBJ = get_fp32_model(MODEL).to(DEVICE)
TRANSFORM = get_transform(MODEL)


@app.post('/predict')
async def predict(file: UploadFile):
    img = Image.open(io.BytesIO(await file.read())).convert("RGB")
    results, latency = run_inference(MODEL_OBJ, img, DEVICE, TRANSFORM)
    return {"top1": results[0][0], "confidence": results[0][1], "latency_ms": latency}
"""


def dockerfile_snippet():
    return """
FROM python:3.10-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
CMD ["python", "app.py", "--cli", "--image", "examples/cat.jpg"]
"""


# ---------------------------------------------
# GRADIO UI
# ---------------------------------------------
examples = [["examples/cat.jpg"], ["examples/dog.jpg"], ["examples/bird.jpg"], ["examples/car.jpg"], ["examples/elephant.jpg"]]


def create_demo():
    with gr.Blocks() as demo:
        gr.Markdown("# 🧠 Model Optimization Lab – Compare, Export, Benchmark")

        device_opts = ["auto", "cpu"]
        if torch.cuda.is_available():
            device_opts.append("cuda")
        if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
            device_opts.append("mps")
        preset_opts = list(PRESETS.keys()) + ["custom"]

        with gr.Tabs():
            # ---- PRUNING TAB ----
            with gr.Tab("Pruning"):
                with gr.Row():
                    with gr.Column():
                        img_p = gr.Image(label="Upload Image")
                        model_p = gr.Dropdown(MODEL_OPTIONS, value=MODEL_OPTIONS[0], label="Base Model")
                        preset_p = gr.Dropdown(preset_opts, value="custom", label="Hardware Preset")
                        method_p = gr.Dropdown(["unstructured", "structured"], value="structured", label="Pruning Method")
                        amount_p = gr.Slider(minimum=0.1, maximum=0.9, step=0.1, value=0.4, label="Pruning Amount")
                        device_p = gr.Dropdown(device_opts, value=device_opts[0], label="Device")
                        channels_last_p = gr.Checkbox(label="Channels-last input (CUDA)", value=True)
                        amp_p = gr.Checkbox(label="Mixed precision (AMP)", value=True)
                        compile_p = gr.Checkbox(label="Torch compile (PyTorch 2)")
                        export_ts_p = gr.Checkbox(label="Export TorchScript")
                        export_onnx_p = gr.Checkbox(label="Export ONNX")
                        export_report_p = gr.Checkbox(label="Export JSON report", value=True)
                        btn_p = gr.Button("Run Pruned Model")
                        gr.Examples(examples=examples, inputs=img_p)

                    with gr.Column():
                        metrics_p = gr.Dataframe(label="📊 Comparison Metrics", headers=["Metric", "Original Model", "Pruned Model"])
                        chart_p = gr.Plot(label="🎯 Top-5 Predictions Comparison")
                        sparsity_p = gr.Dataframe(label="Layer sparsity (%)")
                        downloads_p = gr.Files(label="Exports (state_dict / TorchScript / ONNX / report)")

                btn_p.click(
                    fn=run_pruned,
                    inputs=[
                        img_p,
                        model_p,
                        method_p,
                        amount_p,
                        device_p,
                        channels_last_p,
                        compile_p,
                        amp_p,
                        export_ts_p,
                        export_onnx_p,
                        export_report_p,
                        gr.State(True),
                        preset_p,
                    ],
                    outputs=[metrics_p, chart_p, sparsity_p, downloads_p],
                )

            # ---- QUANTIZATION TAB ----
            with gr.Tab("Quantization"):
                with gr.Row():
                    with gr.Column():
                        img_q = gr.Image(label="Upload Image")
                        model_q = gr.Dropdown(MODEL_OPTIONS, value=MODEL_OPTIONS[0], label="Base Model")
                        preset_q = gr.Dropdown(preset_opts, value="custom", label="Hardware Preset")
                        q_type = gr.Dropdown(["dynamic", "weight_only", "fp16"], value="dynamic", label="Quantization Type")
                        device_q = gr.Dropdown(device_opts, value=device_opts[0], label="Device")
                        channels_last_q = gr.Checkbox(label="Channels-last input (CUDA)", value=True)
                        amp_q = gr.Checkbox(label="Mixed precision (AMP)", value=True)
                        compile_q = gr.Checkbox(label="Torch compile (PyTorch 2)")
                        export_ts_q = gr.Checkbox(label="Export TorchScript")
                        export_onnx_q = gr.Checkbox(label="Export ONNX")
                        export_report_q = gr.Checkbox(label="Export JSON report", value=True)
                        btn_q = gr.Button("Run Quantized Model")
                        gr.Examples(examples=examples, inputs=img_q)

                    with gr.Column():
                        metrics_q = gr.Dataframe(label="📊 Comparison Metrics", headers=["Metric", "Original Model", "Quantized Model"])
                        chart_q = gr.Plot(label="🎯 Top-5 Predictions Comparison")
                        downloads_q = gr.Files(label="Exports (state_dict / TorchScript / ONNX / report)")

                btn_q.click(
                    fn=run_quantized,
                    inputs=[
                        img_q,
                        model_q,
                        q_type,
                        device_q,
                        channels_last_q,
                        compile_q,
                        amp_q,
                        export_ts_q,
                        export_onnx_q,
                        export_report_q,
                        gr.State(True),
                        preset_q,
                    ],
                    outputs=[metrics_q, chart_q, downloads_q],
                )

    return demo


def main():
    parser = argparse.ArgumentParser(description="Model optimization lab")
    parser.add_argument("--cli", action="store_true", help="Run in CLI mode (no UI)")
    parser.add_argument("--mode", choices=["prune", "quant"], default="prune", help="Optimization mode in CLI")
    parser.add_argument("--image", type=str, help="Path to an image for CLI mode")
    parser.add_argument("--model", type=str, default=MODEL_OPTIONS[0], choices=MODEL_OPTIONS)
    parser.add_argument("--device", type=str, default="auto")
    args = parser.parse_args()

    if args.cli:
        if not args.image:
            raise SystemExit("--image is required in CLI mode")
        img = Image.open(args.image)
        if args.mode == "prune":
            # run_pruned returns four values (metrics, chart, sparsity table, downloads)
            metrics, _, _, downloads = run_pruned(img, args.model, "structured", 0.4, device_choice=args.device)
        else:
            metrics, _, downloads = run_quantized(img, args.model, "dynamic", device_choice=args.device)
        print(metrics)
        print("Exports:", downloads)
        return

    demo = create_demo()
    demo.launch()


if __name__ == "__main__":
    main()
```