Nathan12 committed
Commit d1ac7b4 · Parent(s): 7cc1ffa

update app

Files changed (1): app.py +388 -108

app.py CHANGED
@@ -1,26 +1,226 @@
- import fasterai
- from fasterai.sparse.all import *
- from fasterai.prune.all import *
- import torch
  import gradio as gr
  import os
  from torch.ao.quantization import get_default_qconfig_mapping
- import torch.ao.quantization.quantize_fx as quantize_fx
  from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

- class Quant():
      def __init__(self, backend="x86"):
          self.qconfig = get_default_qconfig_mapping(backend)

      def quantize(self, model):
-         x = torch.randn(3, 224, 224)
-         model_prepared = prepare_fx(model.eval(), self.qconfig, x)
          return convert_fx(model_prepared)

-
  def optimize_model(input_model, sparsity, context, criteria):
-
-     model = torch.load(input_model, weights_only=False)
      model = model.eval()
      model = model.to('cpu')
      sp = Sparsifier(model, 'filter', context, criteria=eval(criteria))
@@ -32,146 +232,226 @@ def optimize_model(input_model, sparsity, context, criteria):
      qu_model = qu.quantize(model)

      comp_path = "./comp_model.pth"
-
      scripted = torch.jit.script(qu_model)
      torch.jit.save(scripted, comp_path)
-     #torch.save(qu_model, comp_path)

-     return comp_path

- import matplotlib.pyplot as plt
- import seaborn as sns
- import io
- import numpy as np

- def get_model_size(model_path):
-     """Get model size in MB"""
-     size_bytes = os.path.getsize(model_path)
-     size_mb = size_bytes / (1024 * 1024)
-     return round(size_mb, 2)

- def create_size_comparison_plot(original_size, compressed_size):
-     """Create a bar plot comparing model sizes"""
-     # Set seaborn style
      sns.set_style("darkgrid")
-
-     # Create figure with higher DPI for better resolution
      fig = plt.figure(figsize=(10, 6), dpi=150)
-
-     # Set transparent background
      fig.patch.set_alpha(0.0)
      ax = plt.gca()
      ax.patch.set_alpha(0.0)
-
-     # Plot bars with custom colors and alpha
      bars = plt.bar(['Original', 'Compressed'],
-                    [original_size, compressed_size],
-                    color=['#FF6B00', '#FF9F1C'],
-                    alpha=0.8,
-                    width=0.6)
-
-     # Add size labels on top of bars with improved styling
      for bar in bars:
          height = bar.get_height()
-         plt.text(bar.get_x() + bar.get_width()/2., height + (height * 0.01),
-                  f'{height:.2f} MB',
-                  ha='center', va='bottom',
-                  fontsize=11,
-                  fontweight='bold',
-                  color='white')
-
-     # Calculate compression percentage
-     compression_ratio = ((original_size - compressed_size) / original_size) * 100
-
-     # Customize title and labels with better visibility
-     plt.title(f'Model Size Comparison\nCompression: {compression_ratio:.1f}%',
               fontsize=14,
               fontweight='bold',
               pad=20,
               color='white')
-
-     plt.xlabel('Model Version',
-                fontsize=12,
-                fontweight='bold',
-                labelpad=10,
-                color='white')
-
-     plt.ylabel('Size (MB)',
-                fontsize=12,
-                fontweight='bold',
-                labelpad=10,
-                color='white')
-
-     # Customize grid
      ax.grid(alpha=0.2, color='gray')
-
-     # Remove top and right spines
      sns.despine()
-
-     # Set y-axis limits with some padding
-     max_value = max(original_size, compressed_size)
-     plt.ylim(0, max_value * 1.2)
-
-     # Add more y-axis ticks
-     plt.yticks(np.linspace(0, max_value * 1.2, 10))
-
-     # Make tick labels white
      ax.tick_params(colors='white')
      for spine in ax.spines.values():
          spine.set_color('white')
-
-     # Format axes with white text
      ax.xaxis.label.set_color('white')
      ax.yaxis.label.set_color('white')
      ax.tick_params(axis='x', colors='white')
      ax.tick_params(axis='y', colors='white')
-
-     # Format y-axis tick labels
      ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x:.1f}'))
-
-     # Adjust layout to prevent label cutoff
-     plt.tight_layout()
-
      return fig

- def main_interface(model_name, sparsity, action):
      import torchvision.models as models

-     model_mapping = {
-         'ResNet18': models.resnet18(pretrained=False),
-         'ResNet50': models.resnet50(pretrained=False),
-         'MobileNetV2': models.mobilenet_v2(pretrained=False),
-         'EfficientNet-B0': models.efficientnet_b0(pretrained=False),
-         'VGG16': models.vgg16(pretrained=False),
-     }

-     model = model_mapping[model_name]

-     # Save model temporarily
-     temp_path = "./temp_model.pth"
-     torch.save(model, temp_path)

-     original_size = get_model_size(temp_path)

-     try:
-         compressed_path = optimize_model(temp_path, sparsity, 'local', "large_final")
-         compressed_size = get_model_size(compressed_path)
-         size_plot = create_size_comparison_plot(original_size, compressed_size)

-         return size_plot
-     finally:
-         # Clean up temporary file
-         if os.path.exists(temp_path):
-             os.remove(temp_path)

- available_models = ['ResNet18', 'ResNet50', 'MobileNetV2', 'EfficientNet-B0', 'VGG16']

  iface = gr.Interface(
-     fn=main_interface,
      inputs=[
          gr.Dropdown(choices=available_models, label="Select Model", value='ResNet18'),
-         gr.Slider(label="Compression Level", minimum=0, maximum=100, value=50),
      ],
      outputs=[
          gr.Plot(label="Size Comparison") # Changed from gr.Image to gr.Plot
  import gradio as gr
+ import plotly
+
+ # %% ../nbs/00_benchmark.ipynb 5
+ import torch
+ import time
+ from codecarbon import OfflineEmissionsTracker
+ import numpy as np
  import os
+ from thop import profile, clever_format
+ from thop.vision.basic_hooks import count_convNd, count_linear
+
+ # Map quantized modules to existing conv/linear counters
+ import torch.ao.nn.quantized as nnq
+ import torch.ao.nn.intrinsic.quantized as nniq
+
+ from tqdm.notebook import tqdm
+ from torchprofile import profile_macs
+ from fasterai.sparse.all import *
+ from fasterai.prune.all import *
  from torch.ao.quantization import get_default_qconfig_mapping
  from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import io
+ import copy
+
+ # Simple in-memory caches to avoid recomputation across UI interactions
+ _MODEL_CACHE = {}
+ _COMPRESSED_CACHE = {}
+
+ # %% ../nbs/00_benchmark.ipynb 7
+ def get_model_size(model, temp_path="temp_model.pth"):
+     """Return model disk size in bytes.
+
+     - If model is a path string, returns file size.
+     - If model is an nn.Module, saves state_dict to temp and measures size.
+     - If model is a ScriptModule, saves via torch.jit.save and measures size.
+     """
+     if isinstance(model, str) and os.path.exists(model):
+         return os.path.getsize(model)
+
+     try:
+         torch.save(model.state_dict(), temp_path)
+     except Exception:
+         # Fallback for ScriptModules or objects without state_dict
+         try:
+             torch.jit.save(model, temp_path)
+         except Exception:
+             torch.save(model, temp_path)
+
+     model_size = os.path.getsize(temp_path)
+     os.remove(temp_path)
+
+     return model_size
+
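# Example sketch (illustrative, not part of the commit): get_model_size returns
# bytes whether given an existing file path or a module whose state_dict is
# serialized to a temp file and measured.
import torch
import torch.nn as nn

print(get_model_size(nn.Linear(10, 10)))  # size of the pickled state_dict, in bytes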
+ # %% ../nbs/00_benchmark.ipynb 8
+ def get_num_parameters(model):
+     return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+ # %% ../nbs/00_benchmark.ipynb 11
+ @torch.inference_mode()
+ def evaluate_cpu_speed(model, dummy_input, warmup_rounds=5, test_rounds=25):
+     device = torch.device("cpu")
+     model.eval()
+     model.to(device)
+     dummy_input = dummy_input.to(device)
+
+     # Warm up CPU
+     for _ in range(warmup_rounds):
+         _ = model(dummy_input)
+
+     # Measure Latency
+     latencies = []
+     for _ in range(test_rounds):
+         start_time = time.perf_counter()
+         _ = model(dummy_input)
+         end_time = time.perf_counter()
+         latencies.append(end_time - start_time)
+
+     latencies = np.array(latencies) * 1000  # Convert to milliseconds
+     mean_latency = np.mean(latencies)
+     std_latency = np.std(latencies)
+
+     # Measure Throughput
+     throughput = dummy_input.size(0) * 1000 / mean_latency  # Inferences per second
+
+     return mean_latency, std_latency, throughput
+
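# Example sketch (illustrative, not part of the commit): timing a model with the
# helper above. Throughput is batch_size * 1000 / mean_latency_ms, so a batch of
# 8 counts as 8 inferences per forward pass.
import torch
import torchvision.models as models

_m = models.resnet18(weights=None)
_lat, _std, _ips = evaluate_cpu_speed(_m, torch.randn(8, 3, 224, 224))
print(f"{_lat:.1f} ms +/- {_std:.1f} ms, {_ips:.1f} inferences/s")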
+ # %% ../nbs/00_benchmark.ipynb 13
+ @torch.inference_mode()
+ def get_model_macs(model, inputs) -> int:
+     args = (inputs,) if not isinstance(inputs, (tuple, list)) else tuple(inputs)
+     try:
+         return profile_macs(model, args)
+     except Exception:
+         try:
+             custom_ops = {
+                 nnq.Conv2d: count_convNd,
+                 nniq.ConvReLU2d: count_convNd,
+                 nnq.Linear: count_linear,
+                 nniq.LinearReLU: count_linear,
+             }
+             macs_val, _ = profile(model, inputs=args, custom_ops=custom_ops)
+             return macs_val
+         except Exception:
+             return 0
+
+
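# Worked check (illustrative, not part of the commit): per-layer MACs for a
# convolution follow k*k*C_in*C_out*H_out*W_out. For a ResNet-18 style stem conv
# (7x7 kernel, 3->64 channels, 112x112 output):
stem_macs = 7 * 7 * 3 * 64 * 112 * 112
print(f"{stem_macs / 1e9:.3f} GMAC")  # ~0.118 GMAC, the scale get_model_macs reports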
+ # %% ../nbs/00_benchmark.ipynb 16
+ @torch.inference_mode()
+ def evaluate_emissions(model, dummy_input, warmup_rounds=5, test_rounds=20):
+     device = torch.device("cpu")
+     model.eval()
+     model.to(device)
+     dummy_input = dummy_input.to(device)

+     # Warm up CPU
+     for _ in range(warmup_rounds):
+         _ = model(dummy_input)
+
+     # Measure emissions and energy over the test rounds
+     tracker = OfflineEmissionsTracker(country_iso_code="USA")
+     tracker.start()
+     for _ in range(test_rounds):
+         _ = model(dummy_input)
+     tracker.stop()
+     total_emissions = tracker.final_emissions
+     total_energy_consumed = tracker.final_emissions_data.energy_consumed
+
+     # Calculate average emissions and energy consumption per inference
+     average_emissions_per_inference = total_emissions / test_rounds
+     average_energy_per_inference = total_energy_consumed / test_rounds
+
+     return average_emissions_per_inference, average_energy_per_inference
+
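# Units note (illustrative, not part of the commit): codecarbon reports
# emissions in kg CO2eq and energy in kWh, so the per-inference averages above
# are tiny fractions; code further down scales them by 1e6 into mgCO2e and mWh.
_kg, _kwh = 2e-9, 5e-9  # made-up magnitudes, shown only for the conversion
print(f"{_kg * 1e6:.3f} mgCO2e, {_kwh * 1e6:.3f} mWh")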
+ # %% ../nbs/00_benchmark.ipynb 18
+ @torch.inference_mode()
+ def benchmark(model, dummy_input):
+     # Model Size
+     print('disk size')
+     disk_size = get_model_size(model)
+
+     # CPU Speed
+     print('cpu speed')
+     cpu_latency, cpu_std_latency, cpu_throughput = evaluate_cpu_speed(model, dummy_input)
+
+     # Model MACs and parameters with fallbacks
+     print('macs')
+     macs_str = "0.000G"
+     params_str = "0.000M"
+     try:
+         macs_val, params_val = profile(model, inputs=(dummy_input, ))
+         macs_str, params_str = clever_format([macs_val, params_val], "%.3f")
+     except Exception:
+         try:
+             macs_val = profile_macs(model, (dummy_input,))
+             macs_str = clever_format([macs_val], "%.3f")[0]
+         except Exception:
+             macs_str = "0.000G"
+         try:
+             params_val = sum(p.numel() for p in getattr(model, 'parameters', lambda: [])() if getattr(p, 'requires_grad', False))
+             # convert to M
+             params_str = f"{params_val/1e6:.3f}M"
+         except Exception:
+             params_str = "0.000M"
+
+     print('emissions')
+     # Emissions
+     avg_emissions, avg_energy = evaluate_emissions(model, dummy_input)
+
+     # Print results
+     try:
+         print(f"Model Size: {disk_size / 1e6:.2f} MB (disk), {params_str} parameters")
+     except Exception:
+         pass
+     print(f"CPU Latency: {cpu_latency:.3f} ms (± {cpu_std_latency:.3f} ms)")
+     print(f"CPU Throughput: {cpu_throughput:.2f} inferences/sec")
+     print(f"Model MACs: {macs_str}")
+     print(f"Average Carbon Emissions per Inference: {avg_emissions*1e3:.6f} gCO2e")
+     print(f"Average Energy Consumption per Inference: {avg_energy*1e3:.6f} Wh")
+
+     return {
+         'disk_size': disk_size,
+         'num_parameters': params_str,
+         'cpu_latency': cpu_latency,
+         'cpu_throughput': cpu_throughput,
+         'macs': macs_str,
+         'avg_emissions': avg_emissions,
+         'avg_energy': avg_energy
+     }
+
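# Example sketch (illustrative, not part of the commit): consuming the dict
# returned by benchmark(), e.g. to log a one-line summary.
import torch
import torchvision.models as models

stats = benchmark(models.resnet18(weights=None), torch.randn(1, 3, 224, 224))
print(f"{stats['disk_size'] / 1e6:.1f} MB, {stats['macs']} MACs, {stats['cpu_latency']:.1f} ms")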
+ def parse_metric_value(value_str):
+     """Convert string values with units (M, G) to float"""
+     if isinstance(value_str, (int, float)):
+         return float(value_str)
+
+     value_str = str(value_str)
+     if 'G' in value_str:
+         return float(value_str.replace('G', '')) * 1000  # Convert G to M
+     elif 'M' in value_str:
+         return float(value_str.replace('M', ''))  # Keep in M
+     elif 'K' in value_str:
+         return float(value_str.replace('K', '')) / 1000  # Convert K to M
+     else:
+         return float(value_str)
+
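# Quick checks (illustrative, not part of the commit): parse_metric_value
# normalizes thop-style strings to a common scale of millions.
assert parse_metric_value("2.000G") == 2000.0  # G -> M
assert parse_metric_value("500.000K") == 0.5   # K -> M
assert parse_metric_value(11.69) == 11.69      # plain numbers pass through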
+ # Compression and visualization utilities (merged from Compressor)
+ class Quant:
      def __init__(self, backend="x86"):
          self.qconfig = get_default_qconfig_mapping(backend)

      def quantize(self, model):
+         example_inputs = (torch.randn(1, 3, 224, 224),)
+         model_prepared = prepare_fx(model.eval(), self.qconfig, example_inputs)
          return convert_fx(model_prepared)

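# Note (a hedged sketch, not part of the commit): static post-training
# quantization usually runs calibration batches between prepare_fx and
# convert_fx so the observers can record activation ranges; Quant.quantize
# above converts immediately. A variant with an explicit calibration step:
import torch
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

def quantize_with_calibration(model, qconfig_mapping, calibration_batches):
    example_inputs = (torch.randn(1, 3, 224, 224),)
    prepared = prepare_fx(model.eval(), qconfig_mapping, example_inputs)
    with torch.inference_mode():
        for batch in calibration_batches:  # representative input tensors
            prepared(batch)
    return convert_fx(prepared)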
+ """
221
  def optimize_model(input_model, sparsity, context, criteria):
222
+ #model = torch.load(input_model)
223
+ model = torch.load(input_model, weights_only=False, map_location='cpu')
224
  model = model.eval()
225
  model = model.to('cpu')
226
  sp = Sparsifier(model, 'filter', context, criteria=eval(criteria))
 
232
  qu_model = qu.quantize(model)
233
 
234
  comp_path = "./comp_model.pth"
 
235
  scripted = torch.jit.script(qu_model)
236
  torch.jit.save(scripted, comp_path)
 
237
 
238
+ #return comp_path
239
+ return qu_model
240
+ """
241
 
242
+ def prune_model(input_model, sparsity, context, criteria):
243
+ # Accept either a path or an nn.Module
244
+ if isinstance(input_model, str):
245
+ model = torch.load(input_model, weights_only=False, map_location='cpu')
246
+ else:
247
+ model = input_model
248
+ model = model.eval()
249
+ model = model.to('cpu')
250
+ sp = Sparsifier(model, 'filter', context, criteria=eval(criteria))
251
+ sp.sparsify_model(sparsity)
252
+ sp._clean_buffers()
253
+ pr = Pruner(model, sparsity, context, criteria=eval(criteria))
254
+ pr.prune_model()
255
+ return pr.model
256
 
257
+ def quantize_model(model):
258
+ qu = Quant()
259
+ qu_model = qu.quantize(model)
260
+ return qu_model
 
261
 
262
+ def optimize_model(model, sparsity, context, criteria):
263
+ model = prune_model(model, sparsity, context, criteria)
264
+ model = quantize_model(model)
265
+ return model
266
+
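# Example sketch (illustrative, not part of the commit): the full
# prune-then-quantize pipeline, mirroring how benchmark_interface drives it below.
import torchvision.models as models

_base = models.resnet18(weights=None)
_compressed = optimize_model(_base, 50, 'local', 'large_final')  # 50% filter sparsity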
+
+ def create_size_comparison_plot(before_results, after_results, metrics):
      sns.set_style("darkgrid")
+     # Increase figure size height to accommodate labels better
      fig = plt.figure(figsize=(10, 6), dpi=150)
      fig.patch.set_alpha(0.0)
      ax = plt.gca()
      ax.patch.set_alpha(0.0)
      bars = plt.bar(['Original', 'Compressed'],
+                    [before_results, after_results],
+                    color=['#FF6B00', '#FF9F1C'],
+                    alpha=0.8,
+                    width=0.6)
+     # Dynamic units per metric
+     unit_label_map = {
+         'Latency': 'Latency (ms)',
+         'Size': 'Size (MB)',
+         'MACs': 'MACs (GMAC)',
+         'Energy': 'Energy (mWh)',
+         'Emissions': 'Emissions (mgCO2e)'
+     }
+     def format_value(val, metric):
+         try:
+             fval = float(val)
+         except Exception:
+             fval = 0.0
+         if metric == 'Latency':
+             return f"{fval:.2f} ms"
+         if metric == 'Size':
+             return f"{fval:.2f} MB"
+         if metric == 'MACs':
+             return f"{fval:.3f} GMAC"
+         if metric == 'Energy':
+             return f"{fval:.3f} mWh"
+         if metric == 'Emissions':
+             return f"{fval:.3f} mgCO2e"
+         return f"{fval:.3f}"
+     # Annotate bars with values + units
      for bar in bars:
          height = bar.get_height()
+         offset = (height * 0.02) if height else 0.05
+         plt.text(bar.get_x() + bar.get_width()/2., height + offset,
+                  format_value(height, metrics),
+                  ha='center', va='bottom',
+                  fontsize=11,
+                  fontweight='bold',
+                  color='white')
+     compression_ratio = ((before_results - after_results) / before_results) * 100 if before_results else 0
+     plt.title(f'Model Compression: {compression_ratio:.1f}%',
               fontsize=14,
               fontweight='bold',
               pad=20,
               color='white')
+     plt.xlabel('Model Version', fontsize=12, fontweight='bold', labelpad=10, color='white')
+     plt.ylabel(unit_label_map.get(metrics, metrics), fontsize=12, fontweight='bold', labelpad=10, color='white')
      ax.grid(alpha=0.2, color='gray')
      sns.despine()
+     # Use scientific notation for small Energy/Emissions values
+     if metrics in ('Energy', 'Emissions'):
+         ax.ticklabel_format(style='sci', axis='y', scilimits=(-2, 3))
+     try:
+         max_value = max(float(before_results), float(after_results))
+     except Exception:
+         max_value = float(before_results or after_results or 1)
+     plt.ylim(0, max_value * 1.3)  # Increased upper limit
+     plt.yticks(np.linspace(0, max_value * 1.3, 10))
      ax.tick_params(colors='white')
      for spine in ax.spines.values():
          spine.set_color('white')
      ax.xaxis.label.set_color('white')
      ax.yaxis.label.set_color('white')
      ax.tick_params(axis='x', colors='white')
      ax.tick_params(axis='y', colors='white')
      ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x:.1f}'))
+     plt.tight_layout(pad=3.5)  # Increased padding from 2.5 to 3.5
      return fig

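# Example sketch (illustrative, not part of the commit): rendering a
# before/after chart outside the Gradio callback; values are made up.
_fig = create_size_comparison_plot(44.7, 11.2, 'Size')  # MB before/after
_fig.savefig('comparison.png', transparent=True)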
+ def benchmark_interface(model_name, compression_level, metrics):
      import torchvision.models as models

+     # Cache base models by name
+     if model_name not in _MODEL_CACHE:
+         model_mapping = {
+             'ResNet18': models.resnet18(weights=None),
+             'ResNet50': models.resnet50(weights=None),
+             'MobileNetV2': models.mobilenet_v2(weights=None),
+             'EfficientNet-B0': models.efficientnet_b0(weights=None),
+             'VGG16': models.vgg16(weights=None),
+         }
+         _MODEL_CACHE[model_name] = model_mapping[model_name]
+     model = _MODEL_CACHE[model_name]
+     dummy_input = torch.randn(1, 3, 224, 224)

+     # Benchmark before (convert to readable units for plotting)
+     if metrics == 'Latency':
+         before_results, *_ = evaluate_cpu_speed(model, dummy_input)
+     elif metrics == 'Throughput':
+         *_, before_results = evaluate_cpu_speed(model, dummy_input)
+     elif metrics == 'Size':
+         before_results = get_model_size(model) / 1e6  # MB
+     elif metrics == 'MACs':
+         before_results = get_model_macs(model, dummy_input) / 1e9  # GMAC
+     elif metrics == 'Energy':
+         _, energy_kwh = evaluate_emissions(model, dummy_input)
+         before_results = energy_kwh * 1e6  # mWh
+     elif metrics == 'Emissions':
+         emissions_kg, _ = evaluate_emissions(model, dummy_input)
+         before_results = emissions_kg * 1e6  # mgCO2e
+     else:
+         raise ValueError(f"Invalid metric: {metrics}")

+     # Build or reuse compressed model for the selected compression level
+     cache_key = (model_name, compression_level)
+     if cache_key not in _COMPRESSED_CACHE:
+         sparsity = compression_values[compression_level]
+         model_for_pruning = copy.deepcopy(model)
+         comp_model = prune_model(model_for_pruning, sparsity, "local", "large_final")
+         _COMPRESSED_CACHE[cache_key] = comp_model
+     else:
+         comp_model = _COMPRESSED_CACHE[cache_key]
+
+     # Compute pre-quantization MACs if requested (more robust for tracing)
+     if metrics == 'MACs':
+         after_results = get_model_macs(comp_model, dummy_input) / 1e9  # GMAC
+
+     # Quantize lazily and cache the quantized variant too
+     q_cache_key = (model_name, compression_level, 'quant')
+     if q_cache_key not in _COMPRESSED_CACHE:
+         q_model = quantize_model(comp_model)
+         q_model.eval()
+         _COMPRESSED_CACHE[q_cache_key] = q_model
+     else:
+         q_model = _COMPRESSED_CACHE[q_cache_key]

+     if metrics == 'Latency':
+         after_results, *_ = evaluate_cpu_speed(q_model, dummy_input)
+     elif metrics == 'Throughput':
+         *_, after_results = evaluate_cpu_speed(q_model, dummy_input)
+     elif metrics == 'Size':
+         after_results = get_model_size(q_model) / 1e6  # MB
+     elif metrics == 'MACs':
+         # already computed above (pre-quantization for better compatibility)
+         pass
+     elif metrics == 'Energy':
+         _, energy_kwh_after = evaluate_emissions(q_model, dummy_input)
+         after_results = energy_kwh_after * 1e6  # mWh
+     elif metrics == 'Emissions':
+         emissions_kg_after, _ = evaluate_emissions(q_model, dummy_input)
+         after_results = emissions_kg_after * 1e6  # mgCO2e
+     else:
+         raise ValueError(f"Invalid metric: {metrics}")

+     # Build plots
+     size_plot = create_size_comparison_plot(before_results, after_results, metrics)
+     return size_plot

+ available_models = [
+     'ResNet18',
+     'ResNet50',
+     'MobileNetV2',
+     'EfficientNet-B0',
+     'VGG16'
+ ]
+
+ compression_values = {
+     'Mild 🐌': 25,
+     'Balanced 🐒': 50,
+     'Aggressive 🐇': 75,
+     'Extreme 🐎': 90
+ }

+ metrics = [
+     'Latency',
+     'Size',
+     'MACs',
+     'Energy',
+     'Emissions',
+ ]

  iface = gr.Interface(
+     fn=benchmark_interface,
      inputs=[
          gr.Dropdown(choices=available_models, label="Select Model", value='ResNet18'),
+         gr.Radio(choices=list(compression_values.keys()), label="Compression Level", value='Balanced 🐒'),
+         #gr.Radio(choices=list(target_device.keys()), label="Target Device", value='CPU'),
+         gr.Radio(choices=metrics, label="Comparison Metric", value='Latency'),
      ],
      outputs=[
          gr.Plot(label="Size Comparison") # Changed from gr.Image to gr.Plot