import gradio as gr
# %% ../nbs/00_benchmark.ipynb 5
import torch
import time
from codecarbon import OfflineEmissionsTracker
import numpy as np
import os
from thop import profile, clever_format
from thop.vision.basic_hooks import count_convNd, count_linear
# Map quantized modules to existing conv/linear counters
import torch.ao.nn.quantized as nnq
import torch.ao.nn.intrinsic.quantized as nniq
from tqdm import tqdm  # plain tqdm; the .notebook variant requires ipywidgets
from torchprofile import profile_macs
from fasterai.sparse.all import *
from fasterai.prune.all import *
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
import matplotlib.pyplot as plt
import seaborn as sns
import io
import copy
# Simple in-memory caches to avoid recomputation across UI interactions
_MODEL_CACHE = {}
_COMPRESSED_CACHE = {}
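# _MODEL_CACHE maps model name -> warmed-up base model; _COMPRESSED_CACHE maps
# (name, level) -> pruned model and (name, level, 'quant') -> its quantized
# variant (both populated lazily in benchmark_interface below).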
# %% ../nbs/00_benchmark.ipynb 7
def get_model_size(model, temp_path="temp_model.pth"):
"""Return model disk size in bytes.
- If model is a path string, returns file size.
- If model is an nn.Module, saves state_dict to temp and measures size.
- If model is a ScriptModule, saves via torch.jit.save and measures size.
"""
if isinstance(model, str) and os.path.exists(model):
return os.path.getsize(model)
try:
torch.save(model.state_dict(), temp_path)
except Exception:
# Fallback for ScriptModules or objects without state_dict
try:
torch.jit.save(model, temp_path)
except Exception:
torch.save(model, temp_path)
model_size = os.path.getsize(temp_path)
os.remove(temp_path)
return model_size
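# Example usage (sketch, assuming torchvision is installed):
#   import torchvision.models as models
#   size_mb = get_model_size(models.resnet18(weights=None)) / 1e6  # bytes -> MB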
# %% ../nbs/00_benchmark.ipynb 8
def get_num_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
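# Example: a freshly built resnet18 reports roughly 11.7M trainable parameters.
#   n_params = get_num_parameters(models.resnet18(weights=None))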
# Warm up a model on CPU to stabilize kernel selection and prepack weights
@torch.inference_mode()
def warmup_model(model, num_warmup: int = 10, input_shape=(1, 3, 224, 224)):
try:
model.eval()
device = torch.device("cpu")
model.to(device)
dummy_input = torch.randn(*input_shape, device=device)
for _ in range(num_warmup):
_ = model(dummy_input)
    except Exception:
        # Best-effort warmup: a failed forward pass should not break the UI
        pass
return model
# %% ../nbs/00_benchmark.ipynb 11
@torch.inference_mode()
def evaluate_cpu_speed(model, dummy_input, warmup_rounds=5, test_rounds=25):
device = torch.device("cpu")
model.eval()
model.to(device)
dummy_input = dummy_input.to(device)
# Warm up CPU
for _ in range(warmup_rounds):
_ = model(dummy_input)
# Measure Latency
latencies = []
for _ in range(test_rounds):
start_time = time.perf_counter()
_ = model(dummy_input)
end_time = time.perf_counter()
latencies.append(end_time - start_time)
latencies = np.array(latencies) * 1000 # Convert to milliseconds
mean_latency = np.mean(latencies)
std_latency = np.std(latencies)
# Measure Throughput
throughput = dummy_input.size(0) * 1000 / mean_latency # Inferences per second
return mean_latency, std_latency, throughput
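# Note: benchmark() below relies on evaluate_cpu_speed_raw, which reports
# robust median/percentile latencies; this mean-based variant is kept as-is
# from the original notebook export.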
@torch.inference_mode()
def evaluate_cpu_speed_raw(model, dummy_input, warmup_rounds=10, test_rounds=31):
# Use the SAME instance (no deepcopy) to keep any benign prepack/caches.
m = model.eval().to("cpu")
x = dummy_input.to("cpu")
# Brief warmup: enough for caches, not long enough to throttle
for _ in range(warmup_rounds):
_ = m(x)
# Time individual forwards; take robust stats (median)
    lat_ms = []
    for _ in range(test_rounds):
        t0 = time.perf_counter()
        _ = m(x)
        t1 = time.perf_counter()
        lat_ms.append((t1 - t0) * 1e3)
    lat_ms = np.asarray(lat_ms, dtype=float)
    p50 = float(np.median(lat_ms))
    return {
        "p50_ms": p50,
        "p90_ms": float(np.percentile(lat_ms, 90)),
        "mean_ms": float(lat_ms.mean()),
        "std_ms": float(lat_ms.std()),
        "throughput_ips": float(x.size(0) * 1000.0 / p50),  # scale by batch size
    }
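# Example (sketch): the median is preferred over the mean here because a few
# OS scheduling hiccups can badly skew short CPU timings.
#   stats = evaluate_cpu_speed_raw(model, torch.randn(1, 3, 224, 224))
#   print(f"{stats['p50_ms']:.2f} ms median, {stats['throughput_ips']:.1f} inf/s")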
# %% ../nbs/00_benchmark.ipynb 13
@torch.inference_mode()
def get_model_macs(model, inputs) -> int:
args = (inputs,) if not isinstance(inputs, (tuple, list)) else tuple(inputs)
try:
return profile_macs(model, args)
except Exception:
try:
custom_ops = {
nnq.Conv2d: count_convNd,
nniq.ConvReLU2d: count_convNd,
nnq.Linear: count_linear,
nniq.LinearReLU: count_linear,
}
macs_val, _ = profile(model, inputs=args, custom_ops=custom_ops)
return macs_val
except Exception:
return 0
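# Example: resnet18 at 224x224 input comes out to about 1.8 GMAC.
#   gmacs = get_model_macs(model, torch.randn(1, 3, 224, 224)) / 1e9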
# %% ../nbs/00_benchmark.ipynb 16
@torch.inference_mode()
def evaluate_emissions(model, dummy_input, warmup_rounds=5, test_rounds=20):
device = torch.device("cpu")
model.eval()
model.to(device)
dummy_input = dummy_input.to(device)
    # Warm up CPU
for _ in range(warmup_rounds):
_ = model(dummy_input)
    # Measure emissions over test_rounds inferences
tracker = OfflineEmissionsTracker(country_iso_code="USA")
tracker.start()
for _ in range(test_rounds):
_ = model(dummy_input)
tracker.stop()
total_emissions = tracker.final_emissions
total_energy_consumed = tracker.final_emissions_data.energy_consumed
# Calculate average emissions and energy consumption per inference
average_emissions_per_inference = total_emissions / test_rounds
average_energy_per_inference = total_energy_consumed / test_rounds
return average_emissions_per_inference, average_energy_per_inference
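# Note: codecarbon reports emissions in kgCO2e and energy in kWh, so the
# per-inference values returned above are in those units as well.
#   kg_per_inf, kwh_per_inf = evaluate_emissions(model, torch.randn(1, 3, 224, 224))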
# %% ../nbs/00_benchmark.ipynb 18
@torch.inference_mode()
def benchmark(model, dummy_input):
# Model Size
    print('Measuring disk size...')
disk_size = get_model_size(model)
# CPU Speed
    print('Measuring CPU speed...')
base_stats = evaluate_cpu_speed_raw(model, dummy_input)
cpu_latency = base_stats["p50_ms"]
cpu_std_latency = base_stats["std_ms"]
cpu_throughput = base_stats["throughput_ips"]
# Model MACs and parameters with fallbacks
    print('Counting MACs and parameters...')
macs_str = "0.000G"
params_str = "0.000M"
try:
macs_val, params_val = profile(model, inputs=(dummy_input, ))
macs_str, params_str = clever_format([macs_val, params_val], "%.3f")
except Exception:
        try:
            macs_val = profile_macs(model, (dummy_input,))
            # Format manually: clever_format collapses a single-element list to
            # a bare string, so indexing its result would grab one character.
            macs_str = f"{macs_val / 1e9:.3f}G"
        except Exception:
            macs_str = "0.000G"
        try:
            params_val = get_num_parameters(model)
            params_str = f"{params_val / 1e6:.3f}M"  # convert to millions
        except Exception:
            # Quantized/scripted models may not expose parameters()
            params_str = "0.000M"
    print('Measuring emissions...')
# Emissions
avg_emissions, avg_energy = evaluate_emissions(model, dummy_input)
# Print results
try:
print(f"Model Size: {disk_size / 1e6:.2f} MB (disk), {params_str} parameters")
except Exception:
pass
print(f"CPU Latency: {cpu_latency:.3f} ms (Β± {cpu_std_latency:.3f} ms)")
print(f"CPU Throughput: {cpu_throughput:.2f} inferences/sec")
print(f"Model MACs: {macs_str}")
print(f"Average Carbon Emissions per Inference: {avg_emissions*1e3:.6f} gCO2e")
print(f"Average Energy Consumption per Inference: {avg_energy*1e3:.6f} Wh")
return {
'disk_size': disk_size,
'num_parameters': params_str,
'cpu_latency': cpu_latency,
'cpu_throughput': cpu_throughput,
'macs': macs_str,
'avg_emissions': avg_emissions,
'avg_energy': avg_energy
}
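# Example usage (sketch):
#   results = benchmark(warmup_model(models.resnet18(weights=None)),
#                       torch.randn(1, 3, 224, 224))
#   print(results['cpu_latency'], results['disk_size'])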
def parse_metric_value(value_str):
"""Convert string values with units (M, G) to float"""
if isinstance(value_str, (int, float)):
return float(value_str)
value_str = str(value_str)
if 'G' in value_str:
return float(value_str.replace('G', '')) * 1000 # Convert G to M
elif 'M' in value_str:
return float(value_str.replace('M', '')) # Keep in M
elif 'K' in value_str:
return float(value_str.replace('K', '')) / 1000 # Convert K to M
else:
return float(value_str)
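# Examples: parse_metric_value("1.819G") -> 1819.0, parse_metric_value("11.690M")
# -> 11.69, parse_metric_value("512.000K") -> 0.512 (everything normalized to M).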
# Compression and visualization utilities (merged from Compressor)
class Quant:
def __init__(self, backend="x86"):
self.qconfig = get_default_qconfig_mapping(backend)
def quantize(self, model):
example_inputs = (torch.randn(1, 3, 224, 224),)
model_prepared = prepare_fx(model.eval(), self.qconfig, example_inputs)
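        # NOTE: no calibration pass is run between prepare_fx and convert_fx,
        # so observers keep their default ranges; feeding representative data
        # through model_prepared here would give better quantization parameters.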
return convert_fx(model_prepared)
def prune_model(input_model, sparsity, context, criteria):
# Accept either a path or an nn.Module
if isinstance(input_model, str):
model = torch.load(input_model, weights_only=False, map_location='cpu')
else:
model = input_model
model = model.eval()
model = model.to('cpu')
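    # `criteria` arrives as a string such as "large_final"; eval() resolves it
    # against the fasterai names star-imported above. A dict lookup would be
    # safer if this string ever came from untrusted input.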
sp = Sparsifier(model, 'filter', context, criteria=eval(criteria))
sp.sparsify_model(sparsity)
sp._clean_buffers()
pr = Pruner(model, sparsity, context, criteria=eval(criteria))
pr.prune_model()
return pr.model
def quantize_model(model):
qu = Quant()
qu_model = qu.quantize(model)
return qu_model
def optimize_model(model, sparsity, context, criteria):
model = prune_model(model, sparsity, context, criteria)
model = quantize_model(model)
return model
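# Example (sketch): prune 50% of filters with local context, then quantize.
#   small = optimize_model(copy.deepcopy(model), 50, "local", "large_final")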
def create_size_comparison_plot(before_results, after_results, metrics):
sns.set_style("darkgrid")
# Increase figure size height to accommodate labels better
fig = plt.figure(figsize=(12, 7), dpi=150)
fig.patch.set_alpha(0.0)
ax = plt.gca()
ax.patch.set_alpha(0.0)
bars = plt.bar(['Original', 'Compressed'],
[before_results, after_results],
color=['#FF6B00', '#FF9F1C'],
alpha=0.8,
width=0.6)
# Dynamic units per metric
unit_label_map = {
'Latency': 'Latency (ms)',
'Size': 'Size (MB)',
'MACs': 'MACs (GMAC)',
'Energy': 'Energy (mWh)',
'Emissions': 'Emissions (mgCO2e)'
}
def format_value(val, metric):
try:
fval = float(val)
except Exception:
fval = 0.0
if metric == 'Latency':
return f"{fval:.2f} ms"
if metric == 'Size':
return f"{fval:.2f} MB"
if metric == 'MACs':
return f"{fval:.3f} GMAC"
if metric == 'Energy':
return f"{fval:.3f} mWh"
if metric == 'Emissions':
return f"{fval:.3f} mgCO2e"
return f"{fval:.3f}"
# Annotate bars with values + units
for bar in bars:
height = bar.get_height()
offset = (height * 0.02) if height else 0.05
plt.text(bar.get_x() + bar.get_width()/2., height + offset,
format_value(height, metrics),
ha='center', va='bottom',
fontsize=15,
fontweight='bold',
color='white')
compression_ratio = ((before_results - after_results) / before_results) * 100 if before_results else 0
plt.title(f'Model Compression: {compression_ratio:.1f}%',
fontsize=18,
fontweight='bold',
pad=20,
color='white')
plt.xlabel('Model Version', fontsize=15, fontweight='bold', labelpad=10, color='white')
plt.ylabel(unit_label_map.get(metrics, metrics), fontsize=15, fontweight='bold', labelpad=10, color='white')
ax.grid(alpha=0.2, color='gray')
sns.despine()
# Use scientific notation for small Energy/Emissions values
if metrics in ('Energy', 'Emissions'):
ax.ticklabel_format(style='sci', axis='y', scilimits=(-2, 3))
try:
max_value = max(float(before_results), float(after_results))
except Exception:
max_value = float(before_results or after_results or 1)
    plt.ylim(0, max_value * 1.3)  # headroom above bars for the annotations
plt.yticks(np.linspace(0, max_value * 1.3, 10))
    ax.tick_params(axis='x', colors='white', labelsize=16)
    ax.tick_params(axis='y', colors='white', labelsize=15)
    for tick_label in ax.get_xticklabels():
        tick_label.set_fontweight('bold')
    for spine in ax.spines.values():
        spine.set_color('white')
if metrics not in ('Energy', 'Emissions'):
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x:.1f}'))
    plt.tight_layout(pad=3.5)  # extra padding so labels don't clip
return fig
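# Example (illustrative numbers): create_size_comparison_plot(44.7, 11.2, 'Size')
# renders the before/after bars with MB labels and the reduction in the title.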
def benchmark_interface(model_name, compression_level, metrics):
import torchvision.models as models
    # Cache base models by name; build only the requested architecture instead
    # of instantiating all five on every cache miss
    if model_name not in _MODEL_CACHE:
        model_builders = {
            'ResNet18': models.resnet18,
            'ResNet50': models.resnet50,
            'MobileNetV2': models.mobilenet_v2,
            'EfficientNet-B0': models.efficientnet_b0,
            'VGG16': models.vgg16,
        }
        base_model = model_builders[model_name](weights=None)
        warmup_model(base_model)
        _MODEL_CACHE[model_name] = base_model
    model = _MODEL_CACHE[model_name]
dummy_input = torch.randn(1, 3, 224, 224)
# Benchmark before (convert to readable units for plotting)
if metrics == 'Latency':
base_stats = evaluate_cpu_speed_raw(model, dummy_input)
before_results = base_stats["p50_ms"]
elif metrics == 'Size':
before_results = get_model_size(model) / 1e6 # MB
elif metrics == 'MACs':
before_results = get_model_macs(model, dummy_input) / 1e9 # GMAC
elif metrics == 'Energy':
_, energy_kwh = evaluate_emissions(model, dummy_input)
before_results = energy_kwh * 1e6 # mWh
elif metrics == 'Emissions':
emissions_kg, _ = evaluate_emissions(model, dummy_input)
before_results = emissions_kg * 1e6 # mgCO2e
else:
raise ValueError(f"Invalid metric: {metrics}")
# Build or reuse compressed model for the selected compression level
cache_key = (model_name, compression_level)
if cache_key not in _COMPRESSED_CACHE:
sparsity = compression_values[compression_level]
model_for_pruning = copy.deepcopy(model)
comp_model = prune_model(model_for_pruning, sparsity, "local", "large_final")
_COMPRESSED_CACHE[cache_key] = comp_model
else:
comp_model = _COMPRESSED_CACHE[cache_key]
# Compute pre-quantization MACs if requested (more robust for tracing)
if metrics == 'MACs':
after_results = get_model_macs(comp_model, dummy_input) / 1e9 # GMAC
# Quantize lazily and cache the quantized variant too
q_cache_key = (model_name, compression_level, 'quant')
if q_cache_key not in _COMPRESSED_CACHE:
q_model = quantize_model(comp_model)
q_model.eval()
_COMPRESSED_CACHE[q_cache_key] = q_model
else:
q_model = _COMPRESSED_CACHE[q_cache_key]
if metrics == 'Latency':
base_stats = evaluate_cpu_speed_raw(q_model, dummy_input)
after_results = base_stats["p50_ms"]
elif metrics == 'Size':
after_results = get_model_size(q_model) / 1e6 # MB
elif metrics == 'MACs':
# already computed above (pre-quantization for better compatibility)
pass
elif metrics == 'Energy':
_, energy_kwh_after = evaluate_emissions(q_model, dummy_input)
after_results = energy_kwh_after * 1e6 # mWh
elif metrics == 'Emissions':
emissions_kg_after, _ = evaluate_emissions(q_model, dummy_input)
after_results = emissions_kg_after * 1e6 # mgCO2e
else:
raise ValueError(f"Invalid metric: {metrics}")
# Build plots
size_plot = create_size_comparison_plot(before_results, after_results, metrics)
return size_plot
available_models = [
'ResNet18',
'ResNet50',
'MobileNetV2',
'EfficientNet-B0',
'VGG16'
]
compression_values = {
'Mild 🐌': 25,
'Balanced 🐒': 50,
'Aggressive πŸ‡': 75,
'Extreme 🐎': 90
}
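# Values are sparsity percentages handed to prune_model / Sparsifier.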
metrics = [
'Latency',
'Size',
'MACs',
'Energy',
'Emissions',
]
iface = gr.Interface(
fn=benchmark_interface,
inputs=[
gr.Dropdown(choices=available_models, label="Select Model", value='ResNet18'),
gr.Radio(choices=list(compression_values.keys()), label="Compression Level", value='Balanced 🐒'),
gr.Radio(choices=metrics, label="Comparison Metric", value='Latency'),
],
outputs=[
        gr.Plot(label="Size Comparison")
],
)
iface.launch()