| """ |
| Post-export optimization: quantization, TensorRT, and deployment helpers. |
| """ |
|
|
| import os |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
| def quantize_model(model: nn.Module, calibration_data=None, |
| method: str = 'dynamic') -> nn.Module: |
| """ |
| Quantize model for faster CPU inference. |
| |
| Args: |
| model: PyTorch model |
| calibration_data: DataLoader for static quantization calibration |
| method: 'dynamic' (no calibration needed) or 'static' |
| |
| Returns: |
| Quantized model |
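
    Example (illustrative; `detector` stands in for any trained detection model):
        >>> quantized = quantize_model(detector, method='dynamic')
        >>> torch.save(quantized.state_dict(), 'detector_int8.pt')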
| """ |
| model.eval().cpu() |
|
|
| if method == 'dynamic': |
| quantized = torch.quantization.quantize_dynamic( |
| model, |
| {nn.Linear, nn.Conv2d}, |
| dtype=torch.qint8, |
| ) |
| print("Dynamic quantization complete") |
| return quantized |
|
|
| elif method == 'static': |
| if calibration_data is None: |
| raise ValueError("Static quantization requires calibration_data") |
|
|
| model.qconfig = torch.quantization.get_default_qconfig('fbgemm') |
| prepared = torch.quantization.prepare(model, inplace=False) |
|
|
| |
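        # Run the calibration batches through the prepared model so the
        # observers can record activation statistics.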
        with torch.no_grad():
            for images, _ in calibration_data:
                prepared(images)

        quantized = torch.quantization.convert(prepared, inplace=False)
        print("Static quantization complete")
        return quantized

    else:
        raise ValueError(f"Unknown quantization method: {method}")


def benchmark_deployment(onnx_path: str, input_size: int = 640,
                         num_runs: int = 100) -> dict:
    """
    Benchmark ONNX model inference speed.

    Returns dict with latency and throughput stats.
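
    Example (illustrative path):
        >>> stats = benchmark_deployment('scrfd_34g.onnx', input_size=640)
        >>> print(f"p50 latency: {stats['latency_p50_ms']:.2f} ms")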
| """ |
| import time |
| import numpy as np |
|
|
| try: |
| import onnxruntime as ort |
| except ImportError: |
| return {"error": "onnxruntime not installed"} |
|
|
| providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] |
| session = ort.InferenceSession(onnx_path, providers=providers) |
|
|
| dummy = np.random.randn(1, 3, input_size, input_size).astype(np.float32) |
| input_name = session.get_inputs()[0].name |
|
|
| |
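    # Warm-up runs so one-time initialization and caching do not skew the timings.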
    for _ in range(20):
        session.run(None, {input_name: dummy})

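    # Timed runs.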
    latencies = []
    for _ in range(num_runs):
        t0 = time.perf_counter()
        session.run(None, {input_name: dummy})
        latencies.append((time.perf_counter() - t0) * 1000)

    latencies = np.array(latencies)

    return {
        'onnx_path': onnx_path,
        'input_size': input_size,
        'latency_p50_ms': np.percentile(latencies, 50),
        'latency_p95_ms': np.percentile(latencies, 95),
        'fps': 1000 / np.mean(latencies),
        'provider': session.get_providers()[0],
    }


| TENSORRT_GUIDE = """ |
| # TensorRT Optimization Guide for SCRFD |
| |
| ## Prerequisites |
| pip install tensorrt # or use NVIDIA TensorRT container |
| |
| ## Convert ONNX to TensorRT Engine |
| |
| ### FP16 (recommended for production GPU deployment) |
| trtexec --onnx=scrfd_34g.onnx \\ |
| --saveEngine=scrfd_34g_fp16.engine \\ |
| --fp16 \\ |
| --workspace=4096 \\ |
| --minShapes=input:1x3x640x640 \\ |
| --optShapes=input:1x3x640x640 \\ |
| --maxShapes=input:8x3x640x640 |
| |
| ### INT8 (fastest, requires calibration) |
| trtexec --onnx=scrfd_34g.onnx \\ |
| --saveEngine=scrfd_34g_int8.engine \\ |
| --int8 \\ |
| --calib=calibration_cache.bin \\ |
| --workspace=4096 |
| |
| ### Dynamic batch size |
| trtexec --onnx=scrfd_34g.onnx \\ |
| --saveEngine=scrfd_34g_dynamic.engine \\ |
| --fp16 \\ |
| --minShapes=input:1x3x640x640 \\ |
| --optShapes=input:4x3x640x640 \\ |
| --maxShapes=input:16x3x640x640 |
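
## Sanity-check a built engine from Python (illustrative sketch)
The snippet below only verifies that the engine file deserializes with the
installed TensorRT version; buffer allocation and inference are left out.

    import tensorrt as trt

    logger = trt.Logger(trt.Logger.WARNING)
    with open("scrfd_34g_fp16.engine", "rb") as f, trt.Runtime(logger) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
    assert engine is not None, "engine failed to deserialize"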

## Expected Speedups (V100)
| Model      | PyTorch FP32 | ONNX Runtime | TensorRT FP16 | TensorRT INT8 |
|------------|--------------|--------------|---------------|---------------|
| SCRFD-34G  | ~80 FPS      | ~100 FPS     | ~200 FPS      | ~350 FPS      |
| SCRFD-2.5G | ~400 FPS     | ~500 FPS     | ~800 FPS      | ~1200 FPS     |
| SCRFD-0.5G | ~1000 FPS    | ~1200 FPS    | ~2000 FPS     | ~3000 FPS     |

## INT8 Calibration
Use the calibration script:
python scripts/tensorrt_calibrate.py \\
    --onnx scrfd_34g.onnx \\
    --data-root data/wider_face \\
    --num-images 500

## Deployment with Triton Inference Server
See configs/triton/ for model repository configuration.
"""