|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
import copy |
|
|
import itertools |
|
|
import json |
|
|
import os |
|
|
import os.path as osp |
|
|
import time |
|
|
from functools import partial |
|
|
from pathlib import Path |
|
|
from typing import Optional, Tuple |
|
|
|
|
|
import cv2 |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
import torch.nn as nn |
|
|
import torch.nn.utils.parametrize as parametrize |
|
|
import torch.utils.benchmark as benchmark |
|
|
from mmengine import Config |
|
|
from mmengine.fileio import dump |
|
|
from mmengine.model.utils import revert_sync_batchnorm |
|
|
from mmengine.registry import init_default_scope |
|
|
from mmengine.runner import get_state_dict, load_checkpoint, Runner, save_checkpoint |
|
|
from mmengine.utils import mkdir_or_exist |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from mmseg.apis import init_model |
|
|
|
|
|
|
|
|
from mmseg.registry import MODELS |
|
|
|
|
|
from pytorch2torchscript import pytorch2libtorch |
|
|
|
|
|
from torch._dynamo import is_compiling as dynamo_is_compiling |
|
|
from torch._higher_order_ops.out_dtype import out_dtype |
|
|
from torch.profiler import ProfilerActivity |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
def _benchmark(model, input, model_name=""): |
|
|
imgs = ( |
|
|
input["imgs"][0, ...].unsqueeze(0) |
|
|
if model_name.lower() == "original" |
|
|
else input["imgs"] |
|
|
) |
|
|
if torch.cuda.is_available(): |
|
|
start_event = torch.cuda.Event(enable_timing=True) |
|
|
end_event = torch.cuda.Event(enable_timing=True) |
|
|
|
|
|
time_ = [] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s = torch.cuda.Stream() |
|
|
s.wait_stream(torch.cuda.current_stream()) |
|
|
with torch.cuda.stream(s), torch.no_grad(): |
|
|
for _ in range(10): |
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.synchronize() |
|
|
start_event.record() |
|
|
model(imgs) |
|
|
end_event.record() |
|
|
if torch.cuda.is_available(): |
|
|
end_event.record() |
|
|
torch.cuda.synchronize() |
|
|
time_.append(start_event.elapsed_time(end_event)) |
|
|
torch.cuda.current_stream().wait_stream(s) |
|
|
mean_time = np.mean(time_[1:]) / (imgs.shape[0]) |
|
|
print(f"For {model_name} model, ", flush=True) |
|
|
print(f"avg time is {mean_time} ms", flush=True) |
|
|
print(f"Total time is {sum(time_)} ms", flush=True) |
|
|
print(f"Each trial time: {time_}", flush=True) |
|
|
return mean_time |
|
|
|
|
|
|
|
|
def _convert_batchnorm(module): |
|
|
module_output = module |
|
|
if isinstance(module, (torch.nn.SyncBatchNorm)): |
|
|
module_output = torch.nn.BatchNorm2d( |
|
|
module.num_features, |
|
|
module.eps, |
|
|
module.momentum, |
|
|
module.affine, |
|
|
module.track_running_stats, |
|
|
) |
|
|
if module.affine: |
|
|
module_output.weight.data = module.weight.data.clone().detach() |
|
|
module_output.bias.data = module.bias.data.clone().detach() |
|
|
|
|
|
module_output.weight.requires_grad = module.weight.requires_grad |
|
|
module_output.bias.requires_grad = module.bias.requires_grad |
|
|
module_output.running_mean = module.running_mean |
|
|
module_output.running_var = module.running_var |
|
|
module_output.num_batches_tracked = module.num_batches_tracked |
|
|
if isinstance(module, (torch.nn.SiLU)): |
|
|
module_output = torch.nn.ReLU(inplace=True) |
|
|
for name, child in module.named_children(): |
|
|
module_output.add_module(name, _convert_batchnorm(child)) |
|
|
del module |
|
|
return module_output |
|
|
|
|
|
|
|
|
def _demo_mm_inputs(input_shape, num_classes): |
|
|
"""Create a superset of inputs needed to run test or train batches. |
|
|
|
|
|
Args: |
|
|
input_shape (tuple): |
|
|
input batch dimensions |
|
|
num_classes (int): |
|
|
number of semantic classes |
|
|
""" |
|
|
(N, C, H, W) = input_shape |
|
|
rng = np.random.RandomState(0) |
|
|
imgs = rng.rand(*input_shape) |
|
|
if num_classes > 1: |
|
|
segs = rng.randint(low=0, high=num_classes - 1, size=(N, 1, H, W)).astype( |
|
|
np.uint8 |
|
|
) |
|
|
else: |
|
|
segs = rng.uniform(0, 1, size=(N, 1, H, W)).astype(np.uint8) |
|
|
img_metas = [ |
|
|
{ |
|
|
"img_shape": (H, W, C), |
|
|
"ori_shape": (H, W, C), |
|
|
"pad_shape": (H, W, C), |
|
|
"filename": "<demo>.png", |
|
|
"scale_factor": 1.0, |
|
|
"flip": False, |
|
|
} |
|
|
for _ in range(N) |
|
|
] |
|
|
mm_inputs = { |
|
|
"imgs": torch.FloatTensor(imgs), |
|
|
"img_metas": img_metas, |
|
|
"gt_semantic_seg": torch.LongTensor(segs), |
|
|
} |
|
|
return mm_inputs |
|
|
|
|
|
|
|
|
def explain_model(model, inputs): |
|
|
imgs = inputs["imgs"] |
|
|
with torch.no_grad(): |
|
|
explanation = torch._dynamo.explain(model, imgs) |
|
|
return explanation.graphs, explanation.graph_count, explanation.break_reasons |
|
|
|
|
|
|
|
|
def fuse_model(model):
    """Fuse conv/deconv groups in the decode head in place.

    In ``decode_head.conv_layers`` every ``Conv2d`` at index ``i`` is fused
    with the modules at ``i + 1`` and ``i + 2``. In
    ``decode_head.deconv_layers`` every ``ConvTranspose2d`` at index ``i``
    triggers fusion of the modules at ``i + 1`` and ``i + 2`` (the transposed
    convolution itself is left untouched).

    Args:
        model: Module exposing ``decode_head.conv_layers`` and
            ``decode_head.deconv_layers`` as indexable module containers.

    Returns:
        The same ``model`` object.
    """
    fuse = torch.ao.quantization.fuse_modules
    head = model.decode_head

    conv_stack = head.conv_layers
    for pos in range(len(conv_stack)):
        if not isinstance(conv_stack[pos], torch.nn.Conv2d):
            continue
        group = [str(pos + offset) for offset in (0, 1, 2)]
        fuse(conv_stack, group, inplace=True)
    head.conv_layers = conv_stack

    deconv_stack = head.deconv_layers
    for pos in range(len(deconv_stack)):
        if not isinstance(deconv_stack[pos], torch.nn.ConvTranspose2d):
            continue
        group = [str(pos + offset) for offset in (1, 2)]
        fuse(deconv_stack, group, inplace=True)
    head.deconv_layers = deconv_stack

    return model
|
|
|
|
|
|
|
|
def compile_model(
    model,
    inputs,
    calib_dataloader,
    output_file="compiled_model.pt",
    max_batch_size=32,
    dtype=torch.bfloat16,
):
    """Export ``model``, benchmark it under each ``torch.compile`` mode, and save the export.

    Args:
        model: Module to export; it is switched to eval mode and called with
            ``inputs["imgs"]``.
        inputs (dict): Must contain ``"imgs"``. That entry is overwritten in
            place with the tensor cast to ``dtype`` and moved to CUDA.
        calib_dataloader: When truthy, the model's linear layers are swapped
            for SmoothQuant layers and calibrated on every batch of this
            loader before export. Each batch must support
            ``.to(dtype).cuda()``.
        output_file (str): Path given to ``torch.export.save``.
        max_batch_size (int): Upper bound of the dynamic batch dimension
            (dimension 0 of the ``inputs`` argument).
        dtype (torch.dtype): Dtype used for calibration batches and for the
            example input.

    Returns:
        None. The exported program is written to ``output_file`` and the path
        is printed.
    """
    imgs = inputs["imgs"]
    # Display label -> value passed as ``mode=`` to torch.compile.
    modes = {"Default": "default", "RO": "reduce-overhead", "MA": "max-autotune"}

    min_mean = float("inf")
    best_mode = None

    if calib_dataloader:
        # torchao is only needed on the quantization path, hence the local imports.
        # NOTE(review): several of these names are not used in this function.
        import torchao
        from torch.ao.quantization import quantize
        from torchao.quantization.quant_api import (
            change_linear_weights_to_int4_woqtensors,
            change_linear_weights_to_int8_dqtensors,
            change_linear_weights_to_int8_woqtensors,
            Int8DynActInt4WeightQuantizer,
        )
        from torchao.quantization.smoothquant import (
            smooth_fq_linear_to_inference,
            swap_linear_with_smooth_fq_linear,
        )
        from torchao.utils import unwrap_tensor_subclass

        # These assignments change process-wide dynamo/inductor configuration.
        torch._dynamo.config.automatic_dynamic_shapes = False
        torch._inductor.config.force_fuse_int_mm_with_mul = True
        torch._inductor.config.use_mixed_mm = True
        swap_linear_with_smooth_fq_linear(model)
        print("Calibrating ...")
        model.eval()
        with torch.no_grad():
            for batch in calib_dataloader:
                model(batch.to(dtype).cuda())
        print("Calibration done")
        smooth_fq_linear_to_inference(model)

    # The caller's dict is mutated: "imgs" now lives on CUDA in ``dtype``.
    inputs["imgs"] = inputs["imgs"].to(dtype).cuda()
    imgs = inputs["imgs"]
    model.eval()
    args = (imgs,)
    kwargs = {}
    dynamic_batch = torch.export.Dim("batch", min=1, max=max_batch_size)
    # Keyed by the name "inputs".
    # NOTE(review): presumably this must match the forward() parameter name of
    # ``model`` — confirm against the model class.
    dynamic_shapes = {"inputs": {0: dynamic_batch}}
    with torch.no_grad():
        for mode_str, mode in modes.items():
            print(f"Compiling model with {mode_str} mode")
            # The export call does not depend on ``mode``; it is repeated on
            # every iteration.
            exported_model = torch.export.export(
                model, args=args, kwargs=kwargs, dynamic_shapes=dynamic_shapes
            )
            compiled_model = torch.compile(exported_model.module(), mode=mode)
            # Three warm-up calls on a side stream before timing.
            s = torch.cuda.Stream()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s), torch.no_grad():
                for i in range(3):
                    compiled_model(imgs)
                    torch.cuda.synchronize()
            torch.cuda.current_stream().wait_stream(s)
            mean = _benchmark(compiled_model, inputs, model_name=mode_str)
            if mean < min_mean:
                min_mean = mean
                best_mode = mode_str

    print(f"Best compilation mode: {best_mode}")
    # Saves the program exported in the last loop iteration; the best mode is
    # only reported, it does not select what is saved.
    torch.export.save(exported_model, output_file)
    print(output_file)
|
|
|
|
|
|
|
|
def run(test_loop) -> dict:
    """Run a full test pass over ``test_loop.dataloader`` and return the metrics.

    Hooks are fired on ``test_loop.runner`` in the order ``before_test``,
    ``before_test_epoch``, (per-batch iteration), ``after_test_epoch``,
    ``after_test``.

    Args:
        test_loop: Object exposing ``runner``, ``dataloader`` and ``evaluator``.

    Returns:
        dict: Result of ``evaluator.evaluate`` called with the dataset length.
    """
    for stage in ("before_test", "before_test_epoch"):
        test_loop.runner.call_hook(stage)
    test_loop.runner.model.eval()

    batch_idx = 0
    for data_batch in test_loop.dataloader:
        run_iter(test_loop, batch_idx, data_batch)
        batch_idx += 1

    dataset_size = len(test_loop.dataloader.dataset)
    metrics = test_loop.evaluator.evaluate(dataset_size)
    test_loop.runner.call_hook("after_test_epoch", metrics=metrics)
    test_loop.runner.call_hook("after_test")
    return metrics
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
def run_iter(test_loop, idx, data_batch) -> None: |
|
|
"""Iterate one mini-batch. |
|
|
|
|
|
Args: |
|
|
data_batch (Sequence[dict]): Batch of data from dataloader. |
|
|
""" |
|
|
test_loop.runner.call_hook("before_test_iter", batch_idx=idx, data_batch=data_batch) |
|
|
|
|
|
|
|
|
|
|
|
outputs = test_loop.runner.model.test_step(data_batch) |
|
|
test_loop.evaluator.process(data_samples=outputs, data_batch=data_batch) |
|
|
test_loop.runner.call_hook( |
|
|
"after_test_iter", batch_idx=idx, data_batch=data_batch, outputs=outputs |
|
|
) |
|
|
|
|
|
|
|
|
def calib_loop(runner, model):
    """
    Tensorrt quantization loop

    Wraps ``model`` using the runner's ``model_wrapper_cfg`` setting, installs
    it as ``runner.model`` and then runs ``runner.test()``.
    """
    wrapper_cfg = runner.cfg.get("model_wrapper_cfg")
    wrapped = runner.wrap_model(wrapper_cfg, model)
    runner.model = wrapped
    runner.test()
|
|
|
|
|
|
|
|
def collect_stats(model, data_loader, num_batches=100):
    """Feed data to the network and collect calibration statistics.

    Every ``quant_nn.TensorQuantizer`` that has a calibrator is switched to
    calibration mode (quantization off); quantizers without a calibrator are
    disabled. After at most ``num_batches`` batches have been forwarded, the
    quantizers are switched back.

    NOTE(review): ``quant_nn`` is not imported in the visible part of this
    file — confirm it is in scope before calling.

    Args:
        model: Module exposing ``named_modules()`` and callable on a batch.
        data_loader: Iterable of batches supporting ``.to(torch.float).cuda()``.
        num_batches (int): Maximum number of batches to forward.
    """
    # Enable calibrators.
    for _name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    # islice caps the loop at exactly ``num_batches`` batches. Checking the
    # index after the forward pass would let one extra batch through.
    limited_batches = itertools.islice(data_loader, num_batches)
    for image in tqdm(limited_batches, total=num_batches):
        model(image.to(torch.float).cuda())

    # Disable calibrators and re-enable quantization.
    for _name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()
|
|
|
|
|
|
|
|
def compute_amax(model, **kwargs):
    """Load calibrated amax values into every tensor quantizer, then move the model to CUDA.

    Quantizers whose calibrator is a ``calib.MaxCalibrator`` load amax without
    extra arguments; all other calibrated quantizers receive ``**kwargs``.
    Each tensor quantizer is printed after processing.

    NOTE(review): ``quant_nn`` and ``calib`` are not imported in the visible
    part of this file — confirm they are in scope before calling.

    Args:
        model: Module exposing ``named_modules()`` and ``cuda()``.
        **kwargs: Forwarded to ``load_calib_amax`` for non-max calibrators.
    """
    for module_name, quantizer in model.named_modules():
        if not isinstance(quantizer, quant_nn.TensorQuantizer):
            continue
        calibrator = quantizer._calibrator
        if calibrator is not None:
            uses_max = isinstance(calibrator, calib.MaxCalibrator)
            amax_kwargs = {} if uses_max else kwargs
            quantizer.load_calib_amax(**amax_kwargs)
        print(f"{module_name:40}: {quantizer}")
    model.cuda()
|
|
|
|
|
|
|
|
def quantize_pytorch(model, calib_dataloader):
    """Calibrate ``model`` with histogram calibrators and load percentile amax values.

    The default input quantization descriptor of ``QuantConv2d``,
    ``QuantConvTranspose2d`` and ``QuantLinear`` is set to histogram
    calibration, statistics are collected from ``calib_dataloader`` and amax
    is computed with the 99.99 percentile method.

    NOTE(review): ``QuantDescriptor`` and ``quant_nn`` are not imported in the
    visible part of this file — confirm they are in scope before calling.

    Returns:
        The same ``model`` object.
    """
    histogram_desc = QuantDescriptor(calib_method="histogram")
    for layer_name in ("QuantConv2d", "QuantConvTranspose2d", "QuantLinear"):
        quant_layer = getattr(quant_nn, layer_name)
        quant_layer.set_default_quant_desc_input(histogram_desc)

    with torch.no_grad():
        collect_stats(model, calib_dataloader)
        compute_amax(model, method="percentile", percentile=99.99)

    return model
|
|
|
|
|
|
|
|
def convert_to_tensorrt(
    model, shape, num_classes, calibrate_loop=None, output_file="tensorrt_model.pt"
):
    """Trace ``model``, compile it with Torch-TensorRT and benchmark the result.

    NOTE(review): ``ptq`` and ``torch_tensorrt`` are not imported in the
    visible part of this file — confirm they are in scope before calling.
    NOTE(review): ``output_file`` is never used and ``best_model`` is neither
    returned nor saved, so the function currently only prints timings.

    Args:
        model: Module to trace; put in eval mode, and converted to half
            precision when no calibration loader is given.
        shape: Input shape whose first element (batch) is dropped; the batch
            sizes actually used come from ``input_dynamic_batch_shape``.
        num_classes (int): Forwarded to ``_demo_mm_inputs``.
        calibrate_loop: Optional data loader handed to
            ``ptq.DataLoaderCalibrator``; when truthy, INT8 is requested.
        output_file (str): Unused.

    Returns:
        None.
    """
    # Drop the batch dimension; batch sizes are supplied by the tuple below.
    shape = shape[1:]
    input_dynamic_batch_shape = (20,)

    model.eval()
    min_mean = float("inf")
    best_mode = None
    best_model = None
    for bs in input_dynamic_batch_shape:
        print(f"Tensorrt model with batch size {bs}")
        inputs = _demo_mm_inputs((bs, *shape), num_classes)

        if calibrate_loop:
            calibrator = ptq.DataLoaderCalibrator(
                calibrate_loop,
                use_cache=False,
                algo_type=ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
                device=torch.device("cuda:0"),
            )
            quantized = True
            print("Quantized model")
        else:
            quantized = False

        compile_spec = {
            "inputs": [
                torch_tensorrt.Input(
                    (bs, *shape), dtype=torch.int8 if quantized else torch.half
                )
            ],
            "enabled_precisions": (torch.int8,) if quantized else (torch.half,),
            "truncate_long_and_double": True,
            "require_full_compilation": True,
            "allow_shape_tensors": True,
        }
        if quantized:
            compile_spec["calibrator"] = calibrator
            model.eval()
            # The example input used for tracing is bfloat16 on this path.
            inputs["imgs"] = inputs["imgs"].cuda().to(torch.bfloat16)
            imgs = inputs["imgs"]
        else:
            model.half()
            inputs["imgs"] = inputs["imgs"].half().cuda()
            imgs = inputs["imgs"]

        with torch.no_grad():
            # ``model`` is rebound to the traced module here, so a later loop
            # iteration would trace an already-traced module.
            model = torch.jit.trace(model, imgs)
            trt_model = torch_tensorrt.compile(model, ir="torchscript", **compile_spec)
        if quantized:
            # Cast the benchmark input to the int8 dtype declared in compile_spec.
            inputs["imgs"] = inputs["imgs"].to(torch.int8)
            imgs = inputs["imgs"]
        # Three warm-up calls on a side stream before timing.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s), torch.no_grad():
            for i in range(3):
                trt_model(imgs)
        torch.cuda.current_stream().wait_stream(s)
        mean = _benchmark(
            trt_model, inputs, model_name=f"TensorRT with batch size {bs}"
        )
        if mean < min_mean:
            min_mean = mean
            best_mode = bs
            best_model = copy.deepcopy(trt_model).cpu()
        del trt_model

    print(f"Best batch size: {best_mode}, with avg time {min_mean} ms")
|
|
|
|
|
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv (list[str] | None): Argument list to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            behaviour callers that pass no argument get.

    Returns:
        argparse.Namespace: Parsed options (``config``, ``checkpoint``,
        ``shape``, ``output_dir``, ``max_batch_size``, ``explain_verbose``,
        ``force_compile``, ``quant``, ``fp16``).
    """
    parser = argparse.ArgumentParser(description="Sparsify a model")
    parser.add_argument("config", help="test config file path")
    parser.add_argument("checkpoint", help="checkpoint file")
    parser.add_argument(
        "--shape",
        type=int,
        nargs="+",
        default=[1024, 768],
        help="input image size (height, width)",
    )
    # The value is used as the directory that is created and written to, so
    # the help text describes an output directory.
    parser.add_argument(
        "--output_dir",
        "--output-dir",
        type=str,
        help="output directory for the compiled model",
    )
    parser.add_argument(
        "--max-batch-size",
        type=int,
        default=32,
        help="Maximum batch size for dynamic compile",
    )
    parser.add_argument(
        "--explain-verbose", action="store_true", help="Explains the model compilation"
    )
    parser.add_argument(
        "--force-compile",
        action="store_true",
        help="Force compile the model even if more than one cuda graphs are present",
    )
    parser.add_argument("--quant", action="store_true", help="To enable quantization")
    parser.add_argument(
        "--fp16", action="store_true", help="To enable fp16. Default is bf16"
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
def collate_wrapper(calib_dataloader_collate, *args, **kwargs):
    """Collate a batch with the wrapped function and return only the stacked inputs.

    Args:
        calib_dataloader_collate: Original collate function; its result must
            be a mapping with an ``"inputs"`` sequence of tensors.
        *args, **kwargs: Forwarded unchanged to ``calib_dataloader_collate``.

    Returns:
        torch.Tensor: The ``"inputs"`` tensors stacked along a new dimension 0
        and converted to ``torch.float``.
    """
    collated = calib_dataloader_collate(*args, **kwargs)
    stacked = torch.stack(collated["inputs"], dim=0)
    return stacked.to(dtype=torch.float)
|
|
|
|
|
|
|
|
def main():
    """Benchmark a checkpoint, check it for graph breaks, then export/compile it.

    Steps: parse the CLI options, load the config and model on CPU, benchmark
    the original model, run the dynamo explanation, and, unless more than one
    graph is found without ``--force-compile``, cast the model to the selected
    dtype and hand it to ``compile_model``.

    Raises:
        ValueError: If ``--shape`` has more than two values.
    """
    args = parse_args()

    # Default batch of 16 with 3 channels; a single --shape value means a
    # square input.
    if len(args.shape) == 1:
        input_shape = (16, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (16, 3) + tuple(args.shape)
    else:
        raise ValueError("invalid input shape")

    checkpoint_basename = Path(args.checkpoint).stem

    cfg = Config.fromfile(args.config)
    init_default_scope(cfg.get("default_scope", "mmseg"))
    cfg.model.pretrained = None
    cfg.model.train_cfg = None

    # work_dir priority: CLI > value in the config file > derived from the
    # config filename.
    if args.output_dir is not None:
        cfg.work_dir = args.output_dir
    elif cfg.get("work_dir", None) is None:
        cfg.work_dir = osp.join(
            "./work_dirs", osp.splitext(osp.basename(args.config))[0]
        )

    # Create the directory only after it has been resolved: --output_dir is
    # optional, and os.makedirs(None) raises TypeError.
    output_dir = cfg.work_dir
    os.makedirs(output_dir, exist_ok=True)

    max_batch_size = args.max_batch_size
    # Clamp the example batch size into [1, max_batch_size].
    input_shape = (max(1, min(input_shape[0], max_batch_size)), *input_shape[1:])
    cfg.load_from = args.checkpoint

    calib_dataloader = None
    if args.quant:
        runner = Runner.from_cfg(cfg)
        diff_rank_seed = runner._randomness_cfg.get("diff_rank_seed", False)
        calib_dataloader = runner.build_dataloader(
            cfg.get("test_dataloader"), seed=runner.seed, diff_rank_seed=diff_rank_seed
        )
        # Make the loader yield a single stacked float tensor per batch.
        calib_dataloader.collate_fn = partial(
            collate_wrapper, calib_dataloader.collate_fn
        )
    model = init_model(args.config, args.checkpoint, device="cpu")
    model.eval()

    model = revert_sync_batchnorm(model)

    if isinstance(model.decode_head, torch.nn.ModuleList):
        num_classes = model.decode_head[-1].num_classes
    else:
        num_classes = model.decode_head.num_classes

    mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    if torch.cuda.is_available():
        model.cuda()
        mm_inputs["imgs"] = mm_inputs["imgs"].cuda()

    dtype = torch.bfloat16 if not args.fp16 else torch.half

    _benchmark(model, mm_inputs, "Original")

    graphs, graph_counts, break_reasons = explain_model(model, mm_inputs)

    if args.explain_verbose:
        print(f"Graphs: {graphs}")
        print(f"Graph Counts: {graph_counts}")
        print(f"Reasons: {break_reasons}")

    if not args.force_compile and graph_counts > 1:
        print(f"Graphs are not fusable. Expected 1 graph. Found {graph_counts}")
        return

    model.to(dtype)
    mm_inputs["imgs"] = mm_inputs["imgs"].to(dtype)
    save_path = os.path.join(
        output_dir,
        f"{checkpoint_basename}_{'float16' if dtype==torch.float16 else 'bfloat16'}.pt2",
    )
    # NOTE(review): ``calib_dataloader`` is built above when --quant is set,
    # but ``None`` is passed here, so the calibration path of compile_model is
    # never taken — confirm whether that is intentional.
    compile_model(
        model, mm_inputs, None, save_path, max_batch_size=max_batch_size, dtype=dtype
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|