shriarul5273 committed
Commit 46b3a11 · 1 Parent(s): 7972153

Refactor app.py for model optimization: add pruning and quantization options.

Files changed (4)
  1. .github/workflows/huggingface.yml +25 -0
  2. .gitignore +4 -0
  3. README.md +106 -1
  4. app.py +819 -180
.github/workflows/huggingface.yml ADDED
@@ -0,0 +1,25 @@
+ name: Sync to Hugging Face hub
+ on:
+   push:
+     branches: [main]
+
+   # to run this workflow manually from the Actions tab
+   workflow_dispatch:
+
+ jobs:
+   sync-to-hub:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v5
+         with:
+           fetch-depth: 0
+       - name: Add remote
+         env:
+           HF: ${{ secrets.HF_TOKEN }}
+           HFUSER: ${{ secrets.HFUSER }}
+         run: git remote add space https://$HFUSER:$HF@huggingface.co/spaces/$HFUSER/model-optimization-lab
+       - name: Push to huggingface hub
+         env:
+           HF: ${{ secrets.HF_TOKEN }}
+           HFUSER: ${{ secrets.HFUSER }}
+         run: git push --force https://$HFUSER:$HF@huggingface.co/spaces/$HFUSER/model-optimization-lab main
.gitignore ADDED
@@ -0,0 +1,4 @@
+ *.pth
+ exports/
+ __pycache__/
+ *.pyc
README.md CHANGED
@@ -1 +1,106 @@
- # model-optimization-lab
+ ---
+ title: Model Optimization Lab
+ emoji: 😻
+ colorFrom: pink
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 6.0.0
+ app_file: app.py
+ pinned: false
+ ---
+
+ # Model Optimization Lab
+
+ Interactive Gradio playground for comparing pruning and quantization on ImageNet-classification backbones. Upload any image and observe how latency, confidence, and model size change under different compression recipes. Pretrained weights are loaded by default; set `MODEL_OPT_PRETRAINED=0` if you want random initialization for experimentation.
+
+ ## Features
+ - Baseline FP32 inference using cached backbones (ResNet-50, MobileNetV3, EfficientNet-B0, etc.).
+ - Pruning tab: structured/unstructured pruning with configurable sparsity and size/latency comparison.
+ - Quantization tab: dynamic, weight-only INT8, and FP16 passes with CPU-safe fallbacks for unsupported kernels.
+ - Automated metric tables and Top-5 bar charts to visualize confidence shifts between optimized variants.
+
+ ## Requirements
+ - Python 3.9+
+ - PyTorch with CPU support (GPU optional but recommended for FP16 experiments).
+ - The packages listed in `requirements.txt`, installed via `pip install -r requirements.txt` (if the file is missing, create it with entries such as `torch`, `torchvision`, `timm`, `gradio`, `pandas`, `numpy`, `matplotlib`).
+
+ ## Quick Start
+ 1. Clone the repository:
+    ```bash
+    git clone https://github.com/shriarul5273/model-optimization-lab.git
+    cd model-optimization-lab
+    ```
+ 2. Create and activate a virtual environment (optional but recommended).
+ 3. Install dependencies:
+    ```bash
+    pip install -r requirements.txt
+    ```
+ 4. Launch the Gradio app:
+    ```bash
+    python app.py
+    ```
+ 5. Open the local Gradio URL (printed in the terminal) in your browser.
+
+ ## Using the App
+ 1. **Upload an image** or pick one of the provided examples.
+ 2. Choose a model from the **Base Model** dropdown (ResNet-50, MobileNetV3-Large, EfficientNet-B0, ConvNeXt-Tiny, ViT-B/16, RegNetY-016, EfficientNet-Lite0).
+ 3. Pick a **Hardware Preset** or keep `custom`:
+    - Edge CPU: CPU, channels-last off, dynamic quantization, 30% pruning.
+    - Datacenter GPU: CUDA, channels-last on, `torch.compile`, FP16 quantization, 20% pruning.
+    - Apple MPS: MPS, FP16 quantization, 20% pruning.
+ 4. Pick a tab, set the options, and click **Run**.
+
+ ### Pruning tab options
+ - `Pruning Method`: `structured` (LN-structured) or `unstructured` (L1). Applied to Conv2d weights before export; see the sketch after this list.
+ - `Pruning Amount`: 0.1–0.9 sparsity. Higher values zero more weights; the latency impact depends on kernel support.
+ - `Device`: `auto` picks CUDA → MPS → CPU. Channels-last is only honored on CUDA.
+ - `Channels-last input (CUDA)`: Converts tensors to channels-last for better CUDA kernel throughput.
+ - `Mixed precision (AMP)`: Enables CUDA autocast for FP16/FP32 mixes.
+ - `Torch compile (PyTorch 2)`: Wraps the pruned model in `torch.compile` when available.
+ - Exports: TorchScript (`pruned_model.ts`), ONNX (`pruned_model.onnx`), JSON report; `pruned_state_dict.pth` is always saved.
+ - Outputs: comparison metrics, Top-5 bar chart, per-layer sparsity table, and a download list of artifacts.
+
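+ The two methods map to `torch.nn.utils.prune` calls in `app.py`. A minimal sketch on a toy layer (the `Conv2d` below is a stand-in, not part of the app):
+
+ ```python
+ import torch.nn as nn
+ import torch.nn.utils.prune as prune
+
+ conv = nn.Conv2d(16, 32, kernel_size=3)
+
+ # "unstructured": zero the 40% of individual weights with the smallest L1 magnitude.
+ prune.l1_unstructured(conv, name="weight", amount=0.4)
+
+ # "structured" (alternative): remove whole output channels (dim=0), ranked by L2 norm:
+ # prune.ln_structured(conv, name="weight", amount=0.4, n=2, dim=0)
+
+ # Fold the mask into the weight tensor, as app.py does before measuring size.
+ prune.remove(conv, "weight")
+ print(f"Sparsity: {(conv.weight == 0).float().mean().item():.1%}")
+ ```
+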
+ ### Quantization tab options
+ - `Quantization Type`: `dynamic`/`weight_only` (INT8 linear layers on CPU) or `fp16` (casts the model to half precision); see the sketch after this list.
+ - `Device`: `auto` picks CUDA → MPS → CPU; dynamic/weight-only runs force CPU execution for kernel support.
+ - `Channels-last input (CUDA)`: Same as pruning; ignored on CPU.
+ - `Mixed precision (AMP)`: Applies CUDA autocast to the quantized forward pass.
+ - `Torch compile (PyTorch 2)`: Compiles the quantized model when available.
+ - Exports: TorchScript (`quantized_model.ts`), ONNX (`quantized_model.onnx`), JSON report; `quantized_state_dict.pth` is always saved.
+ - Outputs: comparison metrics, Top-5 bar chart, and a download list of artifacts.
+
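+ Both `dynamic` and `weight_only` reduce to the same call in `app.py` (dynamic quantization of linear layers is weight-only by default). A minimal sketch:
+
+ ```python
+ import timm
+ import torch
+ import torch.nn as nn
+
+ model = timm.create_model("resnet50", pretrained=False).eval()
+
+ # Swap nn.Linear layers for INT8 dynamically quantized equivalents (CPU kernels).
+ quantized = torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
+
+ with torch.no_grad():
+     out = quantized(torch.randn(1, 3, 224, 224))
+ print(out.shape)  # torch.Size([1, 1000])
+ ```
+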
+ ### What gets exported
+ - Artifacts are written to `exports/`. JSON reports include the chosen options, metrics, and Top-5 results for both the baseline and optimized variants.
+ - TorchScript/ONNX exports run best on CPU inputs; failures are logged to the console and skipped. A loading sketch follows this list.
+ - State dicts are always saved for reproducibility; disable or prune them manually if you are embedding this module elsewhere.
+
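+ To sanity-check an export outside the app, the artifacts can be loaded directly (a sketch; `onnxruntime` is assumed to be installed separately, it is not required by the app itself):
+
+ ```python
+ import numpy as np
+ import torch
+
+ # TorchScript archives are self-contained; no model code is needed to load them.
+ ts_model = torch.jit.load("exports/pruned_model.ts")
+ dummy = torch.randn(1, 3, 224, 224)
+ with torch.no_grad():
+     print(ts_model(dummy).shape)
+
+ # ONNX: the input/output names match the torch.onnx.export call in app.py.
+ import onnxruntime as ort
+ session = ort.InferenceSession("exports/pruned_model.onnx")
+ logits = session.run(["output"], {"input": dummy.numpy().astype(np.float32)})[0]
+ print(logits.shape)
+ ```
+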
+ ### Interpreting the Output
+ - **Top-1 Prediction**: Labels come from ImageNet synsets, so some entries include multiple comma-separated synonyms (e.g., `chambered nautilus, pearly nautilus`).
+ - **Latency (ms)**: The measured inference latency for each pass. Large numbers for quantized runs may reflect per-request setup overhead rather than slow model execution; see [Performance Notes](#performance-notes).
+ - **Model Size (MB)**: Size of the serialized state dictionary on disk.
+
+ ## Performance Notes
+ - The current quantization pipeline rebuilds the optimized model on each request, so the reported latency includes setup time. Reusing pre-quantized instances yields more realistic numbers.
+ - Dynamic and weight-only quantization only affect linear layers; ResNet-50 is dominated by convolution blocks that remain FP32, so CPU speedups are modest. Unsupported static INT8 kernels automatically fall back to dynamic quantization.
+ - PyTorch's default quantization backend may fall back to `qnnpack` on CPU. On x86 systems, set `torch.backends.quantized.engine = "fbgemm"` before quantizing for best results, as sketched below.
+ - FP16 inference is beneficial on GPUs. On CPU, PyTorch often casts half tensors back to float32, introducing overhead.
+
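+ A guarded way to select the backend (availability depends on how your PyTorch wheel was built):
+
+ ```python
+ import torch
+
+ # Prefer the x86 fbgemm backend when available; otherwise keep the default.
+ if "fbgemm" in torch.backends.quantized.supported_engines:
+     torch.backends.quantized.engine = "fbgemm"
+ print(torch.backends.quantized.engine)
+ ```
+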
+ ## Extending the Lab
+ - Swap in different architectures by adding entries to `MODEL_OPTIONS` in `app.py` (any `timm.create_model` name works).
+ - Add calibration data and static INT8 quantization to include convolution layers.
+ - Cache optimized models to avoid recomputation across requests; see the sketch after this list.
+ - Integrate evaluation datasets to quantify accuracy drop beyond top-1 confidence.
+
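+ Caching could mirror the `_MODEL_CACHE` pattern `app.py` already uses for FP32 backbones. A rough sketch to drop into `app.py` (the `_OPT_CACHE` name and key layout are illustrative, not existing code):
+
+ ```python
+ _OPT_CACHE: dict[tuple, torch.nn.Module] = {}
+
+ def get_pruned_cached(model_name: str, method: str, amount: float):
+     """Build each (model, method, amount) variant once, then reuse it."""
+     key = (model_name, method, round(float(amount), 2))
+     if key not in _OPT_CACHE:
+         _OPT_CACHE[key] = apply_pruning(clone_model(model_name), amount=amount, method=method)
+     return _OPT_CACHE[key]
+ ```
+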
+ ## CLI Mode
+ - Run without the UI: `python app.py --cli --image path/to/img.jpg --mode prune --model resnet50 --device auto`
+ - Modes: `--mode prune` (structured pruning at 0.4 sparsity) or `--mode quant` (dynamic quantization). Both print the metrics table and the list of exported artifacts.
+ - Devices: `auto` chooses CUDA → MPS → CPU based on availability; `cpu`/`cuda`/`mps` force a device. Dynamic/weight-only quantization forces CPU for kernel support even if a GPU is requested.
+ - Models: any entry from `MODEL_OPTIONS` in `app.py`.
+
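+ The same entry points can also be called programmatically (a sketch; `run_pruned` returns the metrics table, a chart figure, a sparsity table, and the export list):
+
+ ```python
+ from PIL import Image
+
+ from app import run_pruned
+
+ img = Image.open("examples/cat.jpg")
+ metrics, chart, sparsity, downloads = run_pruned(img, "resnet50", "structured", 0.4, device_choice="cpu")
+ print(metrics)
+ print("Exports:", downloads)
+ ```
+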
+ ## Troubleshooting
+ - **Slow downloads**: The first run downloads pretrained weights (~100 MB). Subsequent runs use cached files.
+ - **CUDA errors**: Ensure the correct CUDA-enabled PyTorch build is installed if you intend to run on GPU.
+ - **Quantized model larger than expected**: The state dictionary includes dequantized tensors for some paths (e.g., dynamic quantization). Consider TorchScript or ONNX export for compact deployment artifacts.
+
+ ## License
+ This project inherits the default license of the repository. Replace or update this section if you add a specific license.
app.py CHANGED
@@ -1,33 +1,127 @@
+ import argparse
+ import io
+ import json
+ import os
+ import time
+ import zipfile
+ from pathlib import Path
+ from tempfile import TemporaryDirectory
+
+ import matplotlib.pyplot as plt
+ import gradio as gr
+ import numpy as np
+ import pandas as pd
+ import timm
  import torch
  import torch.nn as nn
+ import torch.nn.functional as F
  import torch.nn.utils.prune as prune
- import timm
- import gradio as gr
- from torchvision import transforms
  from PIL import Image
- import time
- import os
- import pandas as pd
+ from torchvision import transforms


  # ---------------------------------------------
- # Base FP32 Model (Loaded Once)
+ # Base Model Registry / Defaults
  # ---------------------------------------------
- fp32_model = timm.create_model("resnet50", pretrained=True)
- fp32_model.eval()
+ MODEL_OPTIONS = [
+     "resnet50",
+     "mobilenetv3_large_100",
+     "efficientnet_b0",
+     "convnext_tiny",
+     "vit_base_patch16_224",
+     "regnety_016",
+     "efficientnet_lite0",
+ ]
+ PRETRAINED_DEFAULT = os.getenv("MODEL_OPT_PRETRAINED", "1") == "1"
+ _PRETRAINED_DISABLED = False
+ _PRETRAINED_WARNED = False
+
+ PRESETS = {
+     "Edge CPU": {
+         "device": "cpu",
+         "channels_last": False,
+         "compile": False,
+         "quant": "dynamic",
+         "prune_amount": 0.3,
+     },
+     "Datacenter GPU": {
+         "device": "cuda",
+         "channels_last": True,
+         "compile": True,
+         "quant": "fp16",
+         "prune_amount": 0.2,
+     },
+     "Apple MPS": {
+         "device": "mps",
+         "channels_last": False,
+         "compile": False,
+         "quant": "fp16",
+         "prune_amount": 0.2,
+     },
+ }
+
+ _MODEL_CACHE: dict[str, torch.nn.Module] = {}
+ _TRANSFORM_CACHE: dict[str, transforms.Compose] = {}
+
+
+ def select_device(device_str: str) -> torch.device:
+     """Return a valid torch.device based on user selection."""
+     device_str = (device_str or "auto").lower()
+     if device_str == "cuda" and torch.cuda.is_available():
+         return torch.device("cuda")
+     if device_str == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
+         return torch.device("mps")
+     if device_str == "cpu":
+         return torch.device("cpu")
+     return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ def get_transform(model_name: str):
+     if model_name in _TRANSFORM_CACHE:
+         return _TRANSFORM_CACHE[model_name]
+
+     model = get_fp32_model(model_name)
+
+     if hasattr(timm.data, "resolve_model_data_config"):
+         data_cfg = timm.data.resolve_model_data_config(model)
+     else:
+         # Fallback for older timm versions
+         data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
+
+     _TRANSFORM_CACHE[model_name] = timm.data.create_transform(**data_cfg)
+     return _TRANSFORM_CACHE[model_name]
+
+
+ def get_fp32_model(model_name: str):
+     if model_name not in _MODEL_CACHE:
+         global _PRETRAINED_DISABLED, _PRETRAINED_WARNED
+         use_pretrained = PRETRAINED_DEFAULT and not _PRETRAINED_DISABLED
+         if not use_pretrained and not _PRETRAINED_WARNED:
+             print("INFO: MODEL_OPT_PRETRAINED disabled; using randomly initialized weights.")
+             _PRETRAINED_WARNED = True
+         try:
+             loaded = timm.create_model(model_name, pretrained=use_pretrained)
+         except Exception as exc:
+             print(f"Warning: pretrained weights unavailable ({exc}); using random init for {model_name}")
+             _PRETRAINED_DISABLED = True
+             loaded = timm.create_model(model_name, pretrained=False)
+         loaded.eval()
+         _MODEL_CACHE[model_name] = loaded
+     return _MODEL_CACHE[model_name]
+
+
+ def clone_model(model_name: str):
+     """Create a fresh model loaded from the cached FP32 weights to avoid re-downloads."""
+     base = get_fp32_model(model_name)
+     fresh = timm.create_model(model_name, pretrained=False)
+     fresh.load_state_dict(base.state_dict())
+     fresh.eval()
+     return fresh


  # ---------------------------------------------
  # Image Preprocess
  # ---------------------------------------------
- transform = transforms.Compose([
-     transforms.Resize((224, 224)),
-     transforms.ToTensor(),
-     transforms.Normalize([0.485, 0.456, 0.406],
-                          [0.229, 0.224, 0.225])
- ])
-
- # Get ImageNet labels - using class descriptions
  imagenet_info = timm.data.ImageNetInfo()
  labels = [imagenet_info.index_to_description(i) for i in range(1000)]
@@ -40,13 +134,15 @@ def apply_pruning(model, amount=0.5, method="unstructured"):

      for module in model.modules():
          if isinstance(module, nn.Conv2d):
-
              if method == "unstructured":
                  prune.l1_unstructured(module, name="weight", amount=amount)
-
              elif method == "structured":
                  prune.ln_structured(module, name="weight", amount=amount, n=2, dim=0)

+     # Make pruning permanent to reflect true sparsity/size
+     for module in model.modules():
+         if isinstance(module, nn.Conv2d) and hasattr(module, "weight_orig"):
+             prune.remove(module, "weight")
      return model


@@ -54,52 +150,201 @@ def apply_pruning(model, amount=0.5, method="unstructured"):
  # QUANTIZATION FUNCTION (dynamic)
  # ---------------------------------------------
  def apply_quantization(model, q_type="dynamic"):
-     if q_type == "dynamic":
-         return torch.ao.quantization.quantize_dynamic(
-             model, {nn.Linear}, dtype=torch.qint8
-         )
-
-     elif q_type == "weight_only":
-         model_int8 = model
-         for name, module in model_int8.named_modules():
-             if isinstance(module, nn.Linear):
-                 module.weight.data = torch.quantize_per_tensor(
-                     module.weight.data, scale=0.1, zero_point=0, dtype=torch.qint8
-                 ).dequantize()
-         return model_int8
-
-     elif q_type == "fp16":
-         return model.half().eval()
-
-     return model
+     q_type = q_type or "dynamic"
+     if q_type in {"dynamic", "weight_only"}:  # dynamic quantization is weight-only by default
+         return torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
+     if q_type == "fp16":
+         return model.half().eval()
+     return model
+
+
+ def compute_sparsity(model: nn.Module) -> pd.DataFrame:
+     rows = []
+     for name, module in model.named_modules():
+         if hasattr(module, "weight"):
+             weight = module.weight.detach()
+             total = weight.numel()
+             zeros = (weight == 0).sum().item()
+             sparsity = 100.0 * zeros / max(total, 1)
+             rows.append({"Layer": name, "Params": total, "Sparsity %": round(sparsity, 2)})
+     return pd.DataFrame(rows)
+
+
+ def _macs_conv2d(inp, module: nn.Conv2d, out):
+     # inp: (N, C_in, H, W), out: (N, C_out, H_out, W_out)
+     batch, c_in, h, w = inp.shape
+     _, c_out, h_out, w_out = out.shape
+     kernel_ops = module.kernel_size[0] * module.kernel_size[1] * (c_in / module.groups)
+     return batch * h_out * w_out * c_out * kernel_ops
+
+
+ def _macs_linear(inp, module: nn.Linear, out):
+     batch = inp.shape[0]
+     return batch * module.in_features * module.out_features
+
+
+ def layer_profile(model: nn.Module, sample_input: torch.Tensor) -> pd.DataFrame:
+     """Collect per-layer params, MACs, and forward time (single run)."""
+     rows = []
+     handles = []
+     start_times = {}
+
+     def pre_hook(name):
+         def _pre(mod, inp):
+             start_times[name] = time.time()
+         return _pre
+
+     def fwd_hook(name):
+         def _fwd(mod, inp, out):
+             end = time.time()
+             duration_ms = (end - start_times.get(name, end)) * 1000
+             inp0 = inp[0] if isinstance(inp, (tuple, list)) else inp
+             macs = None
+             if isinstance(mod, nn.Conv2d):
+                 macs = _macs_conv2d(inp0, mod, out)
+             elif isinstance(mod, nn.Linear):
+                 macs = _macs_linear(inp0, mod, out)
+             params = sum(p.numel() for p in mod.parameters())
+             rows.append({
+                 "Layer": name,
+                 "Type": mod.__class__.__name__,
+                 "Params": params,
+                 "MACs": macs if macs is None else float(macs),
+                 "Latency (ms)": round(duration_ms, 3),
+             })
+         return _fwd
+
+     for name, module in model.named_modules():
+         if len(list(module.children())) == 0:  # leaf
+             handles.append(module.register_forward_pre_hook(pre_hook(name)))
+             handles.append(module.register_forward_hook(fwd_hook(name)))
+
+     with torch.no_grad():
+         model(sample_input)
+
+     for h in handles:
+         h.remove()
+
+     return pd.DataFrame(rows)
+
+
+ def maybe_compile(model, use_compile: bool):
+     if not use_compile:
+         return model
+     if not hasattr(torch, "compile"):
+         return model
+     try:
+         return torch.compile(model)
+     except Exception:
+         return model
+
+
+ def get_state_dict_size_mb(model: nn.Module) -> float:
+     buffer = io.BytesIO()
+     torch.save(model.state_dict(), buffer)
+     return len(buffer.getbuffer()) / 1e6
+
+
+ def prepare_image(image, transform_fn):
+     if image is None:
+         raise ValueError("No image provided")
+
+     if not isinstance(image, Image.Image):
+         if isinstance(image, np.ndarray) and image.dtype != np.uint8:
+             image = (np.clip(image, 0, 1) * 255).astype(np.uint8)
+         image = Image.fromarray(image.astype("uint8"))
+
+     image = image.convert("RGB")
+     tensor = transform_fn(image).unsqueeze(0)
+     return tensor
+
+
+ def grad_cam(image, model, device, transform_fn):
+     model.eval()
+     target_layer = None
+     for m in reversed(list(model.modules())):
+         if isinstance(m, nn.Conv2d):
+             target_layer = m
+             break
+     if target_layer is None:
+         raise ValueError("No Conv2d layer found for Grad-CAM")
+
+     activations = {}
+     gradients = {}
+
+     def fwd_hook(module, inp, out):
+         activations["value"] = out.detach()
+
+     def bwd_hook(module, grad_in, grad_out):
+         gradients["value"] = grad_out[0].detach()
+
+     handle_fwd = target_layer.register_forward_hook(fwd_hook)
+     handle_bwd = target_layer.register_full_backward_hook(bwd_hook)
+
+     img_t = prepare_image(image, transform_fn).to(device)
+     img_t.requires_grad_(True)
+
+     out = model(img_t)
+     top1 = out.argmax(dim=1)
+     score = out[0, top1]
+     model.zero_grad()
+     score.backward()
+
+     act = activations["value"]
+     grad = gradients["value"]
+     weights = grad.mean(dim=(2, 3), keepdim=True)
+     cam = (weights * act).sum(dim=1, keepdim=True)
+     cam = F.relu(cam)
+     cam = cam.squeeze().cpu().numpy()
+     cam -= cam.min()
+     cam /= cam.max() + 1e-8
+
+     # Resize CAM to image size
+     cam_img = Image.fromarray(np.uint8(cam * 255)).resize(image.size, resample=Image.BILINEAR)
+     heatmap = np.array(cam_img)
+     heatmap_rgb = np.stack([heatmap, np.zeros_like(heatmap), np.zeros_like(heatmap)], axis=-1)
+     overlay = np.array(image.convert("RGB"), dtype=np.float32)
+     alpha = 0.35
+     blended = (overlay * (1 - alpha) + heatmap_rgb * alpha).clip(0, 255).astype("uint8")
+     blended_img = Image.fromarray(blended)
+
+     handle_fwd.remove()
+     handle_bwd.remove()
+     return blended_img


  # ---------------------------------------------
  # Inference Function (shared)
  # ---------------------------------------------
- def run_inference(model, image):
+ def run_inference(model, image, device, transform_fn, channels_last=False, warmup=False, use_amp=False):
      print(f" run_inference called, image type: {type(image)}")
-     if image is None:
-         raise ValueError("No image provided")
-
-     if not isinstance(image, Image.Image):
-         print(" Converting numpy array to PIL Image")
-         image = Image.fromarray(image.astype('uint8'))
-
-     print(" Applying transforms...")
-     img = transform(image).unsqueeze(0)
+     img = prepare_image(image, transform_fn)
+
+     model = model.to(device)
+     img = img.to(device)
+
+     if channels_last and device.type == "cuda":
+         img = img.to(memory_format=torch.channels_last)

      if next(model.parameters()).dtype == torch.float16:
          img = img.half()

+     if warmup:
+         with torch.no_grad():
+             model(img)
+
      print(" Running model inference...")
      start = time.time()
-     with torch.no_grad():
+     amp_ctx = torch.cuda.amp.autocast(enabled=use_amp and device.type == "cuda")
+     with torch.no_grad(), amp_ctx:
          out = model(img)
      latency = (time.time() - start) * 1000

-     print(" Processing results...")
-     prob = torch.softmax(out, dim=1)[0]
+     out_cpu = out.detach().cpu()
+     prob = torch.softmax(out_cpu, dim=1)[0]
      top5_prob, top5_idx = torch.topk(prob, 5)

      results = [(labels[i], float(top5_prob[j])) for j, i in enumerate(top5_idx)]
@@ -107,47 +352,92 @@ def run_inference(model, image):
      return results, latency


+ def build_top5_plot(results_orig, results_other, other_label: str):
+     classes = []
+     for r in results_orig + results_other:
+         if r[0] not in classes:
+             classes.append(r[0])
+     orig_map = {r[0]: r[1] for r in results_orig}
+     other_map = {r[0]: r[1] for r in results_other}
+     orig_vals = [orig_map.get(c, 0.0) for c in classes]
+     other_vals = [other_map.get(c, 0.0) for c in classes]
+     x = np.arange(len(classes))
+     width = 0.35
+     fig, ax = plt.subplots(figsize=(8, 4))
+     ax.bar(x - width / 2, orig_vals, width, label="Original")
+     ax.bar(x + width / 2, other_vals, width, label=other_label)
+     ax.set_xticks(x)
+     ax.set_xticklabels(classes, rotation=20, ha="right")
+     ax.set_ylabel("Confidence")
+     ax.set_xlabel("Class")
+     ax.legend()
+     fig.tight_layout()
+     return fig
+
+
  # ---------------------------------------------
  # Gradio Functions With Options
  # ---------------------------------------------
- def run_pruned(img, method, amount):
+ def run_pruned(
+     img,
+     model_name,
+     method,
+     amount,
+     device_choice="auto",
+     channels_last=False,
+     use_compile=False,
+     use_amp=False,
+     export_ts=False,
+     export_onnx=False,
+     export_report=False,
+     export_state=True,
+     preset=None,
+ ):
      print("\n=== RUN PRUNED CALLED ===")
-     print(f"Image type: {type(img)}, Method: {method}, Amount: {amount}")
+     print(f"Image type: {type(img)}, Model: {model_name}, Method: {method}, Amount: {amount}")

      if img is None:
          print("ERROR: Image is None")
-         return {"Metric": ["Error"], "Original Model": ["No image uploaded"], "Pruned Model": [""]}, {"Class": [], "Original": [], "Pruned": []}
+         return (
+             {"Metric": ["Error"], "Original Model": ["No image uploaded"], "Pruned Model": [""]},
+             {"Class": [], "Original": [], "Pruned": []},
+             pd.DataFrame(),
+             []
+         )
+
+     if preset in PRESETS:
+         preset_cfg = PRESETS[preset]
+         device_choice = preset_cfg["device"]
+         channels_last = preset_cfg["channels_last"]
+         use_compile = preset_cfg["compile"]
+         use_amp = preset_cfg.get("amp", use_amp)
+         amount = preset_cfg.get("prune_amount", amount)
+
+     device = select_device(device_choice)
+     transform_fn = get_transform(model_name)

      # Run original model
      print("Running original model...")
-     results_orig, latency_orig = run_inference(fp32_model, img)
+     fp32_model = get_fp32_model(model_name)
+     results_orig, latency_orig = run_inference(fp32_model, img, device, transform_fn, channels_last, warmup=True, use_amp=use_amp)
      print(f"Original model done. Latency: {latency_orig:.2f}ms")

      # Run pruned model
      print("Creating fresh model...")
-     fresh_model = timm.create_model("resnet50", pretrained=True).eval()
+     fresh_model = clone_model(model_name)
      print("Applying pruning...")
      pruned_model = apply_pruning(fresh_model, amount=float(amount), method=method)
+     pruned_model = maybe_compile(pruned_model, use_compile)
      print("Running pruned model...")
-     results_pruned, latency_pruned = run_inference(pruned_model, img)
+     results_pruned, latency_pruned = run_inference(pruned_model, img, device, transform_fn, channels_last, warmup=True, use_amp=use_amp)
      print(f"Pruned model done. Latency: {latency_pruned:.2f}ms")

-     # Model sizes
-     print("Saving models...")
-     torch.save(fp32_model.state_dict(), "fp32_model.pth")
-
-     # Make pruning permanent by removing the mask (this shows the actual reduced size)
-     for module in pruned_model.modules():
-         if isinstance(module, nn.Conv2d) and hasattr(module, 'weight_orig'):
-             prune.remove(module, 'weight')
-
-     torch.save(pruned_model.state_dict(), "pruned_model.pth")
-     size_orig = os.path.getsize("fp32_model.pth") / 1e6
-     size_pruned = os.path.getsize("pruned_model.pth") / 1e6
+     # Model sizes (in-memory)
+     size_orig = get_state_dict_size_mb(fp32_model)
+     size_pruned = get_state_dict_size_mb(pruned_model)
      print(f"Model sizes - Original: {size_orig:.2f}MB, Pruned: {size_pruned:.2f}MB")

      # Comparison metrics - as DataFrame for Gradio
-     print("Creating metrics dataframe...")
      metrics_df = pd.DataFrame({
          "Metric": ["Top-1 Prediction", "Confidence", "Latency (ms)", "Model Size (MB)"],
          "Original Model": [
@@ -163,45 +453,118 @@ def run_pruned(img, method, amount):
              f"{size_pruned:.2f}"
          ]
      })
-
-     # Top-5 predictions chart data - as DataFrame for BarPlot
-     print("Preparing chart data...")
-     chart_df = pd.DataFrame({
-         "Class": [results_orig[i][0] for i in range(5)],
-         "Original": [results_orig[i][1] for i in range(5)],
-         "Pruned": [results_pruned[i][1] for i in range(5)]
-     })
-
-     print("=== RUN PRUNED COMPLETE ===")
-     return metrics_df, chart_df
+
+     chart_fig = build_top5_plot(results_orig, results_pruned, "Pruned")
+     sparsity_df = compute_sparsity(pruned_model.cpu())
+
+     downloads = []
+     export_dir = Path("exports")
+     export_dir.mkdir(exist_ok=True)
+     sample_cpu = prepare_image(img, transform_fn)
+
+     if export_report:
+         report_path = export_dir / "pruned_report.json"
+         report = {
+             "model": model_name,
+             "pruning": {"method": method, "amount": float(amount)},
+             "metrics": metrics_df.to_dict(),
+             "top5_pruned": results_pruned,
+             "top5_original": results_orig,
+         }
+         report_path.write_text(json.dumps(report, indent=2))
+         downloads.append(str(report_path))
+
+     # Always allow state_dict download for reproducibility
+     if export_state:
+         state_path = export_dir / "pruned_state_dict.pth"
+         torch.save(pruned_model.state_dict(), state_path)
+         downloads.append(str(state_path))
+
+     if export_ts:
+         ts_path = export_dir / "pruned_model.ts"
+         try:
+             scripted = torch.jit.trace(pruned_model.cpu(), sample_cpu)
+             scripted.save(ts_path)
+             downloads.append(str(ts_path))
+         except Exception as exc:
+             print(f"TorchScript export failed: {exc}")
+
+     if export_onnx:
+         onnx_path = export_dir / "pruned_model.onnx"
+         try:
+             torch.onnx.export(
+                 pruned_model.cpu(),
+                 sample_cpu,
+                 onnx_path,
+                 input_names=["input"],
+                 output_names=["output"],
+                 opset_version=13,
+                 dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
+             )
+             downloads.append(str(onnx_path))
+         except Exception as exc:
+             print(f"ONNX export failed: {exc}")
+
+     print("=== RUN PRUNED COMPLETE ===")
+     return metrics_df, chart_fig, sparsity_df, downloads


- def run_quantized(img, q_type):
+ def run_quantized(
+     img,
+     model_name,
+     q_type,
+     device_choice="auto",
+     channels_last=False,
+     use_compile=False,
+     use_amp=False,
+     export_ts=False,
+     export_onnx=False,
+     export_report=False,
+     export_state=True,
+     preset=None,
+ ):
      print("\n=== RUN QUANTIZED CALLED ===")
-     print(f"Image type: {type(img)}, Q-type: {q_type}")
+     print(f"Image type: {type(img)}, Model: {model_name}, Q-type: {q_type}")

      if img is None:
          print("ERROR: Image is None")
-         return {"Metric": ["Error"], "Original Model": ["No image uploaded"], "Quantized Model": [""]}, {"Class": [], "Original": [], "Quantized": []}
+         return (
+             {"Metric": ["Error"], "Original Model": ["No image uploaded"], "Quantized Model": [""]},
+             {"Class": [], "Original": [], "Quantized": []},
+             []
+         )
+
+     if preset in PRESETS:
+         preset_cfg = PRESETS[preset]
+         device_choice = preset_cfg["device"]
+         channels_last = preset_cfg["channels_last"]
+         use_compile = preset_cfg["compile"]
+         use_amp = preset_cfg.get("amp", use_amp)
+         q_type = preset_cfg.get("quant", q_type)
+
+     device = select_device(device_choice)
+     if q_type in {"dynamic", "weight_only"} and device.type != "cpu":
+         print("Dynamic/weight-only quantization uses CPU kernels; switching device to CPU.")
+         device = torch.device("cpu")
+         channels_last = False
+         use_amp = False
+     transform_fn = get_transform(model_name)

      # Run original model
      print("Running original model...")
-     results_orig, latency_orig = run_inference(fp32_model, img)
+     fp32_model = get_fp32_model(model_name)
+     results_orig, latency_orig = run_inference(fp32_model, img, device, transform_fn, channels_last, warmup=True, use_amp=use_amp)
      print(f"Original model done. Latency: {latency_orig:.2f}ms")

      # Run quantized model
-     fresh_model = timm.create_model("resnet50", pretrained=True).eval()
+     fresh_model = clone_model(model_name)
      quant_model = apply_quantization(fresh_model, q_type)
-     results_quant, latency_quant = run_inference(quant_model, img)
+     quant_model = maybe_compile(quant_model, use_compile)
+     results_quant, latency_quant = run_inference(quant_model, img, device, transform_fn, channels_last, warmup=True, use_amp=use_amp)

-     # Model sizes
-     torch.save(fp32_model.state_dict(), "fp32_model.pth")
-     torch.save(quant_model.state_dict(), "quantized_model.pth")
-     size_orig = os.path.getsize("fp32_model.pth") / 1e6
-     size_quant = os.path.getsize("quantized_model.pth") / 1e6
+     size_orig = get_state_dict_size_mb(fp32_model)
+     size_quant = get_state_dict_size_mb(quant_model)

-     # Comparison metrics - as DataFrame for Gradio
-     print("Creating metrics dataframe...")
      metrics_df = pd.DataFrame({
          "Metric": ["Top-1 Prediction", "Confidence", "Latency (ms)", "Model Size (MB)"],
          "Original Model": [
@@ -217,95 +580,371 @@ def run_quantized(img, q_type):
              f"{size_quant:.2f}"
          ]
      })
-
-     # Top-5 predictions chart data - as DataFrame for BarPlot
-     print("Preparing chart data...")
-     chart_df = pd.DataFrame({
-         "Class": [results_orig[i][0] for i in range(5)],
-         "Original": [results_orig[i][1] for i in range(5)],
-         "Quantized": [results_quant[i][1] for i in range(5)]
-     })
+
+     chart_fig = build_top5_plot(results_orig, results_quant, "Quantized")
+
+     downloads = []
+     export_dir = Path("exports")
+     export_dir.mkdir(exist_ok=True)
+     sample_cpu = prepare_image(img, transform_fn)
+
+     if export_report:
+         report_path = export_dir / "quant_report.json"
+         report = {
+             "model": model_name,
+             "quantization": q_type,
+             "metrics": metrics_df.to_dict(),
+             "top5_quantized": results_quant,
+             "top5_original": results_orig,
+         }
+         report_path.write_text(json.dumps(report, indent=2))
+         downloads.append(str(report_path))
+
+     if export_state:
+         state_path = export_dir / "quantized_state_dict.pth"
+         torch.save(quant_model.state_dict(), state_path)
+         downloads.append(str(state_path))
+
+     if export_ts:
+         ts_path = export_dir / "quantized_model.ts"
+         try:
+             scripted = torch.jit.trace(quant_model.cpu(), sample_cpu)
+             scripted.save(ts_path)
+             downloads.append(str(ts_path))
+         except Exception as exc:
+             print(f"TorchScript export failed: {exc}")
+
+     if export_onnx:
+         onnx_path = export_dir / "quantized_model.onnx"
+         try:
+             torch.onnx.export(
+                 quant_model.cpu(),
+                 sample_cpu,
+                 onnx_path,
+                 input_names=["input"],
+                 output_names=["output"],
+                 opset_version=13,
+                 dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
+             )
+             downloads.append(str(onnx_path))
+         except Exception as exc:
+             print(f"ONNX export failed: {exc}")

      print("=== RUN QUANTIZED COMPLETE ===")
-     return metrics_df, chart_df
+     return metrics_df, chart_fig, downloads
+
+
+ def run_profile(image, model_name, variant, device_choice="auto"):
+     if image is None:
+         return pd.DataFrame()
+     device = select_device(device_choice)
+     if variant == "quant" and device.type != "cpu":
+         device = torch.device("cpu")
+     transform_fn = get_transform(model_name)
+     sample = prepare_image(image, transform_fn).to(device)
+
+     if variant == "fp32":
+         model = get_fp32_model(model_name).to(device)
+     elif variant == "pruned":
+         model = apply_pruning(clone_model(model_name), amount=0.4)
+     else:
+         model = apply_quantization(clone_model(model_name), "dynamic")
+
+     profile_df = layer_profile(model, sample)
+     return profile_df
+
+
+ def run_batch(images, model_name, mode, device_choice="auto"):
+     """Batch runner: returns per-image metrics and aggregate stats."""
+     if not images:
+         return pd.DataFrame(), pd.DataFrame()
+
+     device = select_device(device_choice)
+     transform_fn = get_transform(model_name)
+
+     per_image = []
+     latencies = []
+     labels_map = {}
+     expanded_files = []
+     temp_dirs = []
+
+     for path in images:
+         if isinstance(path, str) and path.endswith(".zip"):
+             td = TemporaryDirectory()
+             temp_dirs.append(td)  # keep alive until function ends
+             with zipfile.ZipFile(path) as zf:
+                 zf.extractall(td.name)
+             for root, _, files in os.walk(td.name):
+                 for f in files:
+                     if f.lower().endswith((".jpg", ".jpeg", ".png")):
+                         expanded_files.append(os.path.join(root, f))
+                     if f.lower() in {"labels.txt", "labels.csv"}:
+                         with open(os.path.join(root, f)) as lf:
+                             for line in lf:
+                                 parts = line.strip().split(",")
+                                 if len(parts) >= 2:
+                                     labels_map[parts[0]] = parts[1]
+         else:
+             expanded_files.append(path)
+
+     for path in expanded_files:
+         img = Image.open(path) if isinstance(path, str) else path
+         if mode == "prune":
+             metrics, _, _, _ = run_pruned(
+                 img,
+                 model_name,
+                 "structured",
+                 0.4,
+                 device_choice=device_choice,
+                 export_state=False,
+             )
+             latency = float(metrics.loc[metrics["Metric"] == "Latency (ms)", "Pruned Model"].values[0])
+             top1 = metrics.loc[metrics["Metric"] == "Top-1 Prediction", "Pruned Model"].values[0]
+         else:
+             metrics, _, _ = run_quantized(
+                 img,
+                 model_name,
+                 "dynamic",
+                 device_choice=device_choice,
+                 export_state=False,
+             )
+             latency = float(metrics.loc[metrics["Metric"] == "Latency (ms)", "Quantized Model"].values[0])
+             top1 = metrics.loc[metrics["Metric"] == "Top-1 Prediction", "Quantized Model"].values[0]
+
+         fname = os.path.basename(getattr(path, "name", path))
+         record = {"Image": fname, "Top-1": top1, "Latency (ms)": latency}
+         if fname in labels_map:
+             record["Label"] = labels_map[fname]
+             record["Correct"] = labels_map[fname] == top1
+         per_image.append(record)
+         latencies.append(latency)
+
+     per_image_df = pd.DataFrame(per_image)
+     summary = {
+         "count": len(latencies),
+         "mean_latency": float(np.mean(latencies)),
+         "median_latency": float(np.median(latencies)),
+         "max_latency": float(np.max(latencies)),
+     }
+     if "Correct" in per_image_df.columns:
+         summary["accuracy"] = float(per_image_df["Correct"].mean())
+
+     summary_df = pd.DataFrame({"Metric": list(summary.keys()), "Value": list(summary.values())})
+     return per_image_df, summary_df
+
+
+ def run_sweep(img, model_name, device_choice, experiments_json=None):
+     if img is None:
+         return pd.DataFrame(), pd.DataFrame()
+     default_experiments = [
+         {"mode": "prune", "amount": 0.2, "method": "structured"},
+         {"mode": "prune", "amount": 0.5, "method": "structured"},
+         {"mode": "quant", "q_type": "dynamic"},
+         {"mode": "quant", "q_type": "fp16"},
+     ]
+     try:
+         experiments = json.loads(experiments_json) if experiments_json else default_experiments
+     except Exception:
+         experiments = default_experiments
+
+     rows = []
+     for exp in experiments:
+         if exp.get("mode") == "prune":
+             metrics, _, _, _ = run_pruned(
+                 img,
+                 model_name,
+                 exp.get("method", "structured"),
+                 exp.get("amount", 0.4),
+                 device_choice=device_choice,
+                 export_state=False,
+             )
+             latency = float(metrics.loc[metrics["Metric"] == "Latency (ms)", "Pruned Model"].values[0])
+             size = float(metrics.loc[metrics["Metric"] == "Model Size (MB)", "Pruned Model"].values[0])
+             top1 = metrics.loc[metrics["Metric"] == "Top-1 Prediction", "Pruned Model"].values[0]
+             rows.append({"mode": "prune", "amount": exp.get("amount"), "latency": latency, "size": size, "top1": top1})
+         else:
+             metrics, _, _ = run_quantized(
+                 img,
+                 model_name,
+                 exp.get("q_type", "dynamic"),
+                 device_choice=device_choice,
+                 export_state=False,
+             )
+             latency = float(metrics.loc[metrics["Metric"] == "Latency (ms)", "Quantized Model"].values[0])
+             size = float(metrics.loc[metrics["Metric"] == "Model Size (MB)", "Quantized Model"].values[0])
+             top1 = metrics.loc[metrics["Metric"] == "Top-1 Prediction", "Quantized Model"].values[0]
+             rows.append({"mode": "quant", "q_type": exp.get("q_type"), "latency": latency, "size": size, "top1": top1})
+
+     df = pd.DataFrame(rows)
+     pareto_df = df.rename(columns={"latency": "Latency (ms)", "size": "Model Size (MB)", "top1": "Top-1"})
+     return df, pareto_df
+
+
+ def fastapi_snippet():
+     return """
+ from fastapi import FastAPI, UploadFile
+ from PIL import Image
+ import io, torch, timm
+ from app import get_transform, get_fp32_model, run_inference, select_device
+
+ app = FastAPI()
+ MODEL = "resnet50"
+ DEVICE = select_device("auto")
+ MODEL_OBJ = get_fp32_model(MODEL).to(DEVICE)
+ TRANSFORM = get_transform(MODEL)
+
+
+ @app.post('/predict')
+ async def predict(file: UploadFile):
+     img = Image.open(io.BytesIO(await file.read())).convert("RGB")
+     results, latency = run_inference(MODEL_OBJ, img, DEVICE, TRANSFORM)
+     return {"top1": results[0][0], "confidence": results[0][1], "latency_ms": latency}
+ """
+
+
+ def dockerfile_snippet():
+     return """
+ FROM python:3.10-slim
+ WORKDIR /app
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+ COPY . .
+ CMD ["python", "app.py", "--cli", "--image", "examples/cat.jpg"]
+ """


  # ---------------------------------------------
  # GRADIO UI
  # ---------------------------------------------
- # Example images
- examples = [
-     ["examples/cat.jpg"],
-     ["examples/dog.jpg"],
-     ["examples/bird.jpg"],
-     ["examples/car.jpg"],
-     ["examples/elephant.jpg"]
- ]
-
- with gr.Blocks() as demo:
-     gr.Markdown("# 🧠 ResNet50 Optimization — Select Options to Compare")
-
-     with gr.Tabs():
-
-         # ---- PRUNING TAB ----
-         with gr.Tab("Pruning"):
-             with gr.Row():
-                 with gr.Column():
-                     img_p = gr.Image(label="Upload Image")
-
-                     method_p = gr.Dropdown(
-                         ["unstructured", "structured"],
-                         value="structured",
-                         label="Pruning Method"
-                     )
-
-                     amount_p = gr.Slider(
-                         minimum=0.1, maximum=0.9, step=0.1, value=0.4,
-                         label="Pruning Amount"
-                     )
-
-                     btn_p = gr.Button("Run Pruned Model")
-
-                     gr.Examples(examples=examples, inputs=img_p)
-
-                 with gr.Column():
-                     metrics_p = gr.Dataframe(label="📊 Comparison Metrics", headers=["Metric", "Original Model", "Pruned Model"])
-                     chart_p = gr.BarPlot(
-                         label="🎯 Top-5 Predictions Comparison",
-                         x="Class",
-                         y_title="Confidence",
-                         height=400
-                     )
-
-             btn_p.click(fn=run_pruned, inputs=[img_p, method_p, amount_p], outputs=[metrics_p, chart_p])
-
-
-         # ---- QUANTIZATION TAB ----
-         with gr.Tab("Quantization"):
-             with gr.Row():
-                 with gr.Column():
-                     img_q = gr.Image(label="Upload Image")
-
-                     q_type = gr.Dropdown(
-                         ["dynamic", "weight_only", "fp16"],
-                         value="dynamic",
-                         label="Quantization Type"
-                     )
-
-                     btn_q = gr.Button("Run Quantized Model")
-
-                     gr.Examples(examples=examples, inputs=img_q)
-
-                 with gr.Column():
-                     metrics_q = gr.Dataframe(label="📊 Comparison Metrics", headers=["Metric", "Original Model", "Quantized Model"])
-                     chart_q = gr.BarPlot(
-                         label="🎯 Top-5 Predictions Comparison",
-                         x="Class",
-                         y_title="Confidence",
-                         height=400
-                     )
-
-             btn_q.click(fn=run_quantized, inputs=[img_q, q_type], outputs=[metrics_q, chart_q])
-
-
- demo.launch()
+ examples = [["examples/cat.jpg"], ["examples/dog.jpg"], ["examples/bird.jpg"], ["examples/car.jpg"], ["examples/elephant.jpg"]]
+
+
+ def create_demo():
+     with gr.Blocks() as demo:
+         gr.Markdown("# 🧠 Model Optimization Lab — Compare, Export, Benchmark")
+
+         device_opts = ["auto", "cpu"]
+         if torch.cuda.is_available():
+             device_opts.append("cuda")
+         if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
+             device_opts.append("mps")
+         preset_opts = list(PRESETS.keys()) + ["custom"]
+
+         with gr.Tabs():
+             # ---- PRUNING TAB ----
+             with gr.Tab("Pruning"):
+                 with gr.Row():
+                     with gr.Column():
+                         img_p = gr.Image(label="Upload Image")
+                         model_p = gr.Dropdown(MODEL_OPTIONS, value=MODEL_OPTIONS[0], label="Base Model")
+                         preset_p = gr.Dropdown(preset_opts, value="custom", label="Hardware Preset")
+                         method_p = gr.Dropdown(["unstructured", "structured"], value="structured", label="Pruning Method")
+                         amount_p = gr.Slider(minimum=0.1, maximum=0.9, step=0.1, value=0.4, label="Pruning Amount")
+                         device_p = gr.Dropdown(device_opts, value=device_opts[0], label="Device")
+                         channels_last_p = gr.Checkbox(label="Channels-last input (CUDA)", value=True)
+                         amp_p = gr.Checkbox(label="Mixed precision (AMP)", value=True)
+                         compile_p = gr.Checkbox(label="Torch compile (PyTorch 2)")
+                         export_ts_p = gr.Checkbox(label="Export TorchScript")
+                         export_onnx_p = gr.Checkbox(label="Export ONNX")
+                         export_report_p = gr.Checkbox(label="Export JSON report", value=True)
+                         btn_p = gr.Button("Run Pruned Model")
+                         gr.Examples(examples=examples, inputs=img_p)
+
+                     with gr.Column():
+                         metrics_p = gr.Dataframe(label="📊 Comparison Metrics", headers=["Metric", "Original Model", "Pruned Model"])
+                         chart_p = gr.Plot(label="🎯 Top-5 Predictions Comparison")
+                         sparsity_p = gr.Dataframe(label="Layer sparsity (%)")
+                         downloads_p = gr.Files(label="Exports (state_dict / TorchScript / ONNX / report)")
+
+                 btn_p.click(
+                     fn=run_pruned,
+                     inputs=[
+                         img_p,
+                         model_p,
+                         method_p,
+                         amount_p,
+                         device_p,
+                         channels_last_p,
+                         compile_p,
+                         amp_p,
+                         export_ts_p,
+                         export_onnx_p,
+                         export_report_p,
+                         gr.State(True),
+                         preset_p,
+                     ],
+                     outputs=[metrics_p, chart_p, sparsity_p, downloads_p],
+                 )
+
+             # ---- QUANTIZATION TAB ----
+             with gr.Tab("Quantization"):
+                 with gr.Row():
+                     with gr.Column():
+                         img_q = gr.Image(label="Upload Image")
+                         model_q = gr.Dropdown(MODEL_OPTIONS, value=MODEL_OPTIONS[0], label="Base Model")
+                         preset_q = gr.Dropdown(preset_opts, value="custom", label="Hardware Preset")
+                         q_type = gr.Dropdown(["dynamic", "weight_only", "fp16"], value="dynamic", label="Quantization Type")
+                         device_q = gr.Dropdown(device_opts, value=device_opts[0], label="Device")
+                         channels_last_q = gr.Checkbox(label="Channels-last input (CUDA)", value=True)
+                         amp_q = gr.Checkbox(label="Mixed precision (AMP)", value=True)
+                         compile_q = gr.Checkbox(label="Torch compile (PyTorch 2)")
+                         export_ts_q = gr.Checkbox(label="Export TorchScript")
+                         export_onnx_q = gr.Checkbox(label="Export ONNX")
+                         export_report_q = gr.Checkbox(label="Export JSON report", value=True)
+                         btn_q = gr.Button("Run Quantized Model")
+                         gr.Examples(examples=examples, inputs=img_q)
+
+                     with gr.Column():
+                         metrics_q = gr.Dataframe(label="📊 Comparison Metrics", headers=["Metric", "Original Model", "Quantized Model"])
+                         chart_q = gr.Plot(label="🎯 Top-5 Predictions Comparison")
+                         downloads_q = gr.Files(label="Exports (state_dict / TorchScript / ONNX / report)")
+
+                 btn_q.click(
+                     fn=run_quantized,
+                     inputs=[
+                         img_q,
+                         model_q,
+                         q_type,
+                         device_q,
+                         channels_last_q,
+                         compile_q,
+                         amp_q,
+                         export_ts_q,
+                         export_onnx_q,
+                         export_report_q,
+                         gr.State(True),
+                         preset_q,
+                     ],
+                     outputs=[metrics_q, chart_q, downloads_q],
+                 )
+
+     return demo
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Model optimization lab")
+     parser.add_argument("--cli", action="store_true", help="Run in CLI mode (no UI)")
+     parser.add_argument("--mode", choices=["prune", "quant"], default="prune", help="Optimization mode in CLI")
+     parser.add_argument("--image", type=str, help="Path to an image for CLI mode")
+     parser.add_argument("--model", type=str, default=MODEL_OPTIONS[0], choices=MODEL_OPTIONS)
+     parser.add_argument("--device", type=str, default="auto")
+     args = parser.parse_args()
+
+     if args.cli:
+         if not args.image:
+             raise SystemExit("--image is required in CLI mode")
+         img = Image.open(args.image)
+         if args.mode == "prune":
+             # run_pruned returns (metrics, chart figure, sparsity table, downloads)
+             metrics, _, _, downloads = run_pruned(img, args.model, "structured", 0.4, device_choice=args.device)
+         else:
+             metrics, _, downloads = run_quantized(img, args.model, "dynamic", device_choice=args.device)
+         print(metrics)
+         print("Exports:", downloads)
+         return
+
+     demo = create_demo()
+     demo.launch()
+
+
+ if __name__ == "__main__":
+     main()