# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import copy
import itertools
import json
import os
import os.path as osp
import time
from contextlib import nullcontext
from functools import partial
from pathlib import Path
from typing import Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
import torch.utils.benchmark as benchmark
from mmengine import Config
from mmengine.fileio import dump
from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner import get_state_dict, load_checkpoint, Runner, save_checkpoint
from mmengine.utils import mkdir_or_exist

# import torch_tensorrt
# import torch_tensorrt.ts.ptq as ptq
# from pytorch_quantization.tensor_quant import QuantDescriptor
# from pytorch_quantization import quant_modules
#
# quant_modules.initialize()
# from pytorch_quantization import nn as quant_nn
# from pytorch_quantization import calib
# import modelopt.torch.quantization as mtq
# from modelopt.torch.quantization.utils import export_torch_mode
# from modelopt.torch.quantization.nn import TensorQuantizer

from mmseg.apis import init_model

# from mmseg.models import build_segmentor
from mmseg.registry import MODELS
from pytorch2torchscript import pytorch2libtorch
from torch._dynamo import is_compiling as dynamo_is_compiling
from torch._higher_order_ops.out_dtype import out_dtype
from torch.profiler import ProfilerActivity
from tqdm import tqdm


def _benchmark(model, input, model_name="", num_iters=10):
    """Time ``model`` on the demo batch and print a short latency report.

    Args:
        model: Callable taking a single image tensor.
        input (dict): Inputs as produced by ``_demo_mm_inputs``; only
            ``input["imgs"]`` is read.
        model_name (str): Label used in the report. When it equals
            ``"original"`` (case-insensitive) only the first image of the
            batch is timed.
        num_iters (int): Number of timed forward passes. The first pass is
            treated as warm-up and excluded from the average.

    Returns:
        float: Mean latency per image in milliseconds, excluding the first
        iteration.
    """
    imgs = input["imgs"]
    if model_name.lower() == "original":
        # The eager baseline is timed on a single image.
        imgs = imgs[0, ...].unsqueeze(0)

    use_cuda = torch.cuda.is_available()
    time_ = []

    if use_cuda:
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        # Run on a side stream that is ordered after work already queued on
        # the current stream.
        stream = torch.cuda.Stream()
        stream.wait_stream(torch.cuda.current_stream())
        stream_ctx = torch.cuda.stream(stream)
    else:
        # No CUDA: fall back to wall-clock timing on the host.
        stream = None
        stream_ctx = nullcontext()

    with stream_ctx, torch.no_grad():
        for _ in range(num_iters):
            if use_cuda:
                torch.cuda.synchronize()
                start_event.record()
                model(imgs)
                end_event.record()
                torch.cuda.synchronize()
                time_.append(start_event.elapsed_time(end_event))
            else:
                t0 = time.perf_counter()
                model(imgs)
                time_.append((time.perf_counter() - t0) * 1000.0)

    if stream is not None:
        torch.cuda.current_stream().wait_stream(stream)

    # Skip the warm-up iteration; per-image latency = batch latency / N.
    timed = time_[1:] or time_
    mean_time = np.mean(timed) / (imgs.shape[0])
    print(f"For {model_name} model, ", flush=True)
    print(f"avg time is {mean_time} ms", flush=True)
    print(f"Total time is {sum(time_)} ms", flush=True)
    print(f"Each trial time: {time_}", flush=True)
    return mean_time


def _convert_batchnorm(module):
    """Recursively replace ``SyncBatchNorm`` layers with ``BatchNorm2d``.

    Affine parameters are cloned; running statistics buffers are assigned
    directly, so they are shared with the original layer.

    NOTE(review): this helper also replaces every ``nn.SiLU`` with an
    in-place ``nn.ReLU``, which changes the network's activations. Confirm
    that is intended before reusing it as a plain SyncBN converter.

    Args:
        module (nn.Module): Root module to convert.

    Returns:
        nn.Module: A new module for converted layers, otherwise ``module``
        itself with its children converted.
    """
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    if isinstance(module, torch.nn.SiLU):
        module_output = torch.nn.ReLU(inplace=True)
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    return module_output


def _demo_mm_inputs(input_shape, num_classes):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): input batch dimensions ``(N, C, H, W)``
        num_classes (int): number of semantic classes

    Returns:
        dict: ``imgs`` (float tensor of ``input_shape`` with values drawn
        from ``[0, 1)``), ``img_metas`` (one metadata dict per image) and
        ``gt_semantic_seg`` (int64 tensor of shape ``(N, 1, H, W)``). The
        generator is seeded with 0, so results are deterministic.
    """
    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    if num_classes > 1:
        # ``high`` is exclusive: use ``num_classes`` so every label in
        # ``[0, num_classes - 1]`` can appear.
        segs = rng.randint(low=0, high=num_classes, size=(N, 1, H, W)).astype(
            np.uint8
        )
    else:
        # uniform [0, 1) truncated to uint8 -> an all-zero mask.
        segs = rng.uniform(0, 1, size=(N, 1, H, W)).astype(np.uint8)
    img_metas = [
        {
            "img_shape": (H, W, C),
            "ori_shape": (H, W, C),
            "pad_shape": (H, W, C),
            "filename": ".png",
            "scale_factor": 1.0,
            "flip": False,
        }
        for _ in range(N)
    ]
    mm_inputs = {
        "imgs": torch.FloatTensor(imgs),
        "img_metas": img_metas,
        "gt_semantic_seg": torch.LongTensor(segs),
    }
    return mm_inputs


def explain_model(model, inputs):
    """Run TorchDynamo's ``explain`` on ``model(inputs["imgs"])``.

    Returns:
        tuple: ``(graphs, graph_count, break_reasons)`` from the
        ``ExplainOutput``.
    """
    imgs = inputs["imgs"]
    with torch.no_grad():
        # ``explain(f)(*args)`` is the supported form; ``explain(f, *args)``
        # is deprecated.
        explanation = torch._dynamo.explain(model)(imgs)
    return explanation.graphs, explanation.graph_count, explanation.break_reasons


def fuse_model(model):
    """Fuse layer groups in ``model.decode_head`` in place via
    ``torch.ao.quantization.fuse_modules``.

    Each ``Conv2d`` in ``decode_head.conv_layers`` is fused with the two
    layers that follow it. For each ``ConvTranspose2d`` in
    ``decode_head.deconv_layers``, only the two layers following it are
    fused; the transposed conv itself is left out.

    NOTE(review): presumably the sequences are Conv-BN-ReLU triples; a
    ``Conv2d`` in the last two positions would raise because the indices
    ``idx + 1`` / ``idx + 2`` do not exist.

    Returns:
        nn.Module: ``model`` (modified in place).
    """
    fuse_modules = torch.ao.quantization.fuse_modules

    decode_convs = model.decode_head.conv_layers
    for idx in range(len(decode_convs)):
        if isinstance(decode_convs[idx], torch.nn.Conv2d):
            fuse_modules(
                decode_convs, [str(idx), str(idx + 1), str(idx + 2)], inplace=True
            )
    model.decode_head.conv_layers = decode_convs

    decode_deconvs = model.decode_head.deconv_layers
    for idx in range(len(decode_deconvs)):
        if isinstance(decode_deconvs[idx], torch.nn.ConvTranspose2d):
            fuse_modules(decode_deconvs, [str(idx + 1), str(idx + 2)], inplace=True)
    model.decode_head.deconv_layers = decode_deconvs
    return model


def compile_model(
    model,
    inputs,
    calib_dataloader,
    output_file="compiled_model.pt",
    max_batch_size=32,
    dtype=torch.bfloat16,
):
    """Export ``model`` with a dynamic batch dimension, benchmark several
    ``torch.compile`` modes on it and save the exported program.

    Args:
        model (nn.Module): Model whose ``forward`` takes the image tensor as
            a parameter named ``inputs`` (that name keys ``dynamic_shapes``).
        inputs (dict): Demo inputs; ``inputs["imgs"]`` is the example batch.
            In the quantization path this entry is replaced in place by a
            ``dtype`` CUDA copy.
        calib_dataloader: Optional iterable of image batches. When truthy,
            linear layers are swapped for torchao SmoothQuant fake-quant
            layers, calibrated on every batch and converted for inference
            before export.
        output_file (str): Destination passed to ``torch.export.save``.
        max_batch_size (int): Upper bound of the dynamic batch dimension.
        dtype (torch.dtype): dtype for calibration batches and for the demo
            inputs in the quantization path.

    Note:
        Requires CUDA. The saved file holds the exported program only; the
        best ``torch.compile`` mode is printed but not persisted.
    """
    imgs = inputs["imgs"]
    # Candidate torch.compile modes, keyed by the label used in the report.
    modes = {"Default": "default", "RO": "reduce-overhead", "MA": "max-autotune"}
    # Other experimented candidates (torchao quantization instead of modes):
    # modes = {"MA": "max-autotune"}
    # modes = {"int8_dq": change_linear_weights_to_int8_dqtensors}
    # modes = {"int8_int4": Int8DynActInt4WeightQuantizer(groupsize=128).quantize}
    min_mean = float("inf")
    best_mode = None

    if calib_dataloader:
        from torchao.quantization.smoothquant import (
            smooth_fq_linear_to_inference,
            swap_linear_with_smooth_fq_linear,
        )

        torch._dynamo.config.automatic_dynamic_shapes = False
        torch._inductor.config.force_fuse_int_mm_with_mul = True
        torch._inductor.config.use_mixed_mm = True

        swap_linear_with_smooth_fq_linear(model)
        print("Calibrating ...")
        model.eval()
        with torch.no_grad():
            for batch in calib_dataloader:
                model(batch.to(dtype).cuda())
        print("Calibration done")
        smooth_fq_linear_to_inference(model)
        inputs["imgs"] = inputs["imgs"].to(dtype).cuda()
        imgs = inputs["imgs"]

    model.eval()
    dynamic_batch = torch.export.Dim("batch", min=1, max=max_batch_size)
    dynamic_shapes = {"inputs": {0: dynamic_batch}}

    with torch.no_grad():
        # The exported graph does not depend on the compile mode: export once.
        exported_model = torch.export.export(
            model, args=(imgs,), kwargs={}, dynamic_shapes=dynamic_shapes
        )
        for mode_str, mode in modes.items():
            print(f"Compiling model with {mode_str} mode")
            # ``module()`` builds a fresh GraphModule on every call, so each
            # mode is compiled from a clean module.
            compiled_model = torch.compile(exported_model.module(), mode=mode)

            # Warm-up runs (trigger compilation) on a side stream.
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                for _ in range(3):
                    compiled_model(imgs)
                torch.cuda.synchronize()
            torch.cuda.current_stream().wait_stream(stream)

            mean = _benchmark(compiled_model, inputs, model_name=mode_str)
            if mean < min_mean:
                min_mean = mean
                best_mode = mode_str

    print(f"Best compilation mode: {best_mode}")
    torch.export.save(exported_model, output_file)
    print(output_file)


def run(test_loop) -> dict:
    """Launch test."""
    test_loop.runner.call_hook("before_test")
    test_loop.runner.call_hook("before_test_epoch")
    test_loop.runner.model.eval()
    for idx, data_batch in enumerate(test_loop.dataloader):
        run_iter(test_loop, idx, data_batch)

    # compute metrics
    metrics = test_loop.evaluator.evaluate(len(test_loop.dataloader.dataset))
    test_loop.runner.call_hook("after_test_epoch", metrics=metrics)
    test_loop.runner.call_hook("after_test")
    return metrics


@torch.no_grad()
def run_iter(test_loop, idx, data_batch) -> None:
    """Iterate one mini-batch.

    Args:
        test_loop: Loop object providing ``runner``, ``evaluator``.
        idx (int): Index of the batch within the epoch.
        data_batch (Sequence[dict]): Batch of data from dataloader.
    """
    test_loop.runner.call_hook("before_test_iter", batch_idx=idx, data_batch=data_batch)
    # predictions should be sequence of BaseDataElement
    outputs = test_loop.runner.model.test_step(data_batch)
    test_loop.evaluator.process(data_samples=outputs, data_batch=data_batch)
    test_loop.runner.call_hook(
        "after_test_iter", batch_idx=idx, data_batch=data_batch, outputs=outputs
    )


def calib_loop(runner, model):
    """Tensorrt quantization loop: wrap ``model`` into ``runner`` and run
    ``runner.test()`` over the test set."""
    runner.model = runner.wrap_model(runner.cfg.get("model_wrapper_cfg"), model)
    runner.test()


def collect_stats(model, data_loader, num_batches=100):
    """Feed up to ``num_batches`` batches through the network with the
    pytorch-quantization calibrators enabled, collecting their statistics.

    NOTE(review): ``quant_nn`` is only imported in a commented-out line at
    the top of this file; restore that import before calling this.
    """
    # Enable calibrators
    for module in model.modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    # islice feeds exactly ``num_batches`` batches (the previous break-after
    # check processed one extra).
    for image in tqdm(itertools.islice(data_loader, num_batches), total=num_batches):
        model(image.to(torch.float).cuda())

    # Disable calibrators
    for module in model.modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()


def compute_amax(model, **kwargs):
    """Load calibration results (amax) into every ``TensorQuantizer``.

    Max calibrators are loaded without arguments; other calibrators receive
    ``kwargs`` (e.g. ``method="percentile"``). The model is moved to CUDA
    afterwards.

    NOTE(review): ``quant_nn`` and ``calib`` are only imported in
    commented-out lines at the top of this file.
    """
    # Load calib result
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)
            print(f"{name:40}: {module}")
    model.cuda()


def quantize_pytorch(model, calib_dataloader):
    """Calibrate ``model`` with pytorch-quantization histogram calibrators
    and load 99.99th-percentile amax values.

    NOTE(review): the default input quant descriptors only affect quantized
    modules constructed after they are set; presumably the model was built
    after ``quant_modules.initialize()`` — confirm. ``QuantDescriptor`` and
    ``quant_nn`` are only imported in commented-out lines.
    """
    quant_desc_input = QuantDescriptor(calib_method="histogram")
    quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantConvTranspose2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
    with torch.no_grad():
        collect_stats(model, calib_dataloader)
        compute_amax(model, method="percentile", percentile=99.99)
    return model


def convert_to_tensorrt(
    model, shape, num_classes, calibrate_loop=None, output_file="tensorrt_model.pt"
):
    """Compile ``model`` with Torch-TensorRT (TorchScript IR) for each
    candidate batch size and report the fastest one.

    Currently disabled in ``main``. NOTE(review): ``torch_tensorrt`` and
    ``ptq`` are only imported in commented-out lines at the top of this
    file, so calling this raises ``NameError`` until they are restored.

    Args:
        model (nn.Module): Model to convert; cast to half precision in the
            non-quantized path.
        shape (tuple): Full input shape ``(N, C, H, W)``; the batch entry is
            replaced by each candidate batch size.
        num_classes (int): Forwarded to ``_demo_mm_inputs``.
        calibrate_loop: When truthy, used as the data source of an INT8
            ``ptq.DataLoaderCalibrator``.
        output_file (str): Intended save path; unused while the save call
            below stays commented out.
    """
    shape = shape[1:]
    input_dynamic_batch_shape = (20,)
    model.eval()
    min_mean = float("inf")
    best_mode = None
    best_model = None
    for bs in input_dynamic_batch_shape:
        print(f"Tensorrt model with batch size {bs}")
        inputs = _demo_mm_inputs((bs, *shape), num_classes)
        if calibrate_loop:
            # Other experimented quantization methods:
            # quant_cfg = mtq.INT8_SMOOTHQUANT_CFG.copy()
            # quant_cfg["quant_cfg"]["*InstanceNorm*"] = {"enable": False}
            # quant_cfg["quant_cfg"]["*InstanceNorm*weight_quantizer"] = {"enable": False}
            # mtq.register(original_cls=torch.nn.InstanceNorm2d, quantized_cls=QuantizedInstanceNorm)
            # mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
            # model = quantize_pytorch(model, calibrate_loop)
            # quant_nn.TensorQuantizer.use_fb_fake_quant = True
            calibrator = ptq.DataLoaderCalibrator(
                calibrate_loop,
                use_cache=False,
                algo_type=ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
                device=torch.device("cuda:0"),
            )
            quantized = True
            print("Quantized model")
        else:
            quantized = False

        compile_spec = {
            "inputs": [
                torch_tensorrt.Input(
                    (bs, *shape), dtype=torch.int8 if quantized else torch.half
                )
            ],
            "enabled_precisions": (torch.int8,) if quantized else (torch.half,),
            "truncate_long_and_double": True,
            "require_full_compilation": True,
            "allow_shape_tensors": True,
        }
        if quantized:
            compile_spec["calibrator"] = calibrator
            model.eval()
            inputs["imgs"] = inputs["imgs"].cuda().to(torch.bfloat16)
            imgs = inputs["imgs"]
        else:
            model.half()
            inputs["imgs"] = inputs["imgs"].half().cuda()
            imgs = inputs["imgs"]

        with torch.no_grad():
            # Trace into a separate name so ``model`` stays the eager module
            # for the next batch size instead of re-tracing a ScriptModule.
            traced_model = torch.jit.trace(model, imgs)
            trt_model = torch_tensorrt.compile(
                traced_model, ir="torchscript", **compile_spec
            )

        if quantized:
            inputs["imgs"] = inputs["imgs"].to(torch.int8)
            imgs = inputs["imgs"]

        # Warm-up runs on a side stream.
        stream = torch.cuda.Stream()
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream), torch.no_grad():
            for _ in range(3):
                trt_model(imgs)
        torch.cuda.current_stream().wait_stream(stream)

        mean = _benchmark(
            trt_model, inputs, model_name=f"TensorRT with batch size {bs}"
        )
        if mean < min_mean:
            min_mean = mean
            best_mode = bs
            best_model = copy.deepcopy(trt_model).cpu()
        del trt_model

    print(f"Best batch size: {best_mode}, with avg time {min_mean} ms")
    # torch_tensorrt.save(best_model.cuda(), output_file, inputs=[imgs])


def parse_args():
    """Parse command-line arguments for export/compile benchmarking."""
    parser = argparse.ArgumentParser(
        description="Benchmark, export and torch.compile an MMSeg model"
    )
    parser.add_argument("config", help="test config file path")
    parser.add_argument("checkpoint", help="checkpoint file")
    parser.add_argument(
        "--shape",
        type=int,
        nargs="+",
        default=[1024, 768],
        help="input image size (height, width)",
    )
    parser.add_argument(
        "--output_dir",
        "--output-dir",
        type=str,
        help="output directory for the exported model "
        "(defaults to the config's work_dir)",
    )
    parser.add_argument(
        "--max-batch-size",
        type=int,
        default=32,
        help="Maximum batch size for dynamic compile",
    )
    parser.add_argument(
        "--explain-verbose", action="store_true", help="Explains the model compilation"
    )
    parser.add_argument(
        "--force-compile",
        action="store_true",
        help="Force compile the model even if more than one cuda graphs are present",
    )
    parser.add_argument("--quant", action="store_true", help="To enable quantization")
    parser.add_argument(
        "--fp16", action="store_true", help="To enable fp16. Default is bf16"
    )
    args = parser.parse_args()
    return args


def collate_wrapper(calib_dataloader_collate, *args, **kwargs):
    """Collate with ``calib_dataloader_collate`` and keep only the images.

    Returns:
        torch.Tensor: The collated ``"inputs"`` list stacked along a new
        batch dimension and cast to float32.
    """
    collated_inputs = calib_dataloader_collate(*args, **kwargs)["inputs"]
    return torch.stack(collated_inputs, dim=0).to(dtype=torch.float)


def main():
    """Benchmark the eager model, then export, compile and save it."""
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (16, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (16, 3) + tuple(args.shape)
    else:
        raise ValueError("invalid input shape")

    checkpoint_basename = Path(args.checkpoint).stem

    cfg = Config.fromfile(args.config)
    init_default_scope(cfg.get("default_scope", "mmseg"))
    cfg.model.pretrained = None
    # build the model and load checkpoint
    cfg.model.train_cfg = None

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.output_dir is not None:
        cfg.work_dir = args.output_dir
    elif cfg.get("work_dir", None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join(
            "./work_dirs", osp.splitext(osp.basename(args.config))[0]
        )
    # ``--output-dir`` is optional: use the resolved work_dir so the export
    # path below never receives ``None``.
    output_dir = cfg.work_dir
    os.makedirs(output_dir, exist_ok=True)

    max_batch_size = args.max_batch_size
    input_shape = (max(1, min(input_shape[0], max_batch_size)), *input_shape[1:])
    cfg.load_from = args.checkpoint

    calib_dataloader = None
    if args.quant:
        runner = Runner.from_cfg(cfg)
        diff_rank_seed = runner._randomness_cfg.get("diff_rank_seed", False)
        calib_dataloader = runner.build_dataloader(
            cfg.get("test_dataloader"), seed=runner.seed, diff_rank_seed=diff_rank_seed
        )
        calib_dataloader.collate_fn = partial(
            collate_wrapper, calib_dataloader.collate_fn
        )

    model = init_model(args.config, args.checkpoint, device="cpu")
    model.eval()
    # convert SyncBN to BN
    model = revert_sync_batchnorm(model)

    if isinstance(model.decode_head, torch.nn.ModuleList):
        num_classes = model.decode_head[-1].num_classes
    else:
        num_classes = model.decode_head.num_classes

    mm_inputs = _demo_mm_inputs(input_shape, num_classes)
    if torch.cuda.is_available():
        model.cuda()
        mm_inputs["imgs"] = mm_inputs["imgs"].cuda()

    dtype = torch.bfloat16 if not args.fp16 else torch.half

    _benchmark(model, mm_inputs, "Original")
    graphs, graph_counts, break_reasons = explain_model(model, mm_inputs)
    if args.explain_verbose:
        print(f"Graphs: {graphs}")
        print(f"Graph Counts: {graph_counts}")
        print(f"Reasons: {break_reasons}")
    if not args.force_compile and graph_counts > 1:
        print(f"Graphs are not fusable. Expected 1 graph. Found {graph_counts}")
        return

    model.to(dtype)
    mm_inputs["imgs"] = mm_inputs["imgs"].to(dtype)
    dtype_name = "float16" if dtype == torch.float16 else "bfloat16"
    save_path = os.path.join(output_dir, f"{checkpoint_basename}_{dtype_name}.pt2")
    # Pass the calibration loader through so ``--quant`` actually takes effect.
    compile_model(
        model,
        mm_inputs,
        calib_dataloader,
        save_path,
        max_batch_size=max_batch_size,
        dtype=dtype,
    )

    # Tensorrt disabled for now
    # save_path = os.path.join(output_dir, f"{checkpoint_basename}_trt.ep")
    # convert_to_tensorrt(model, input_shape, num_classes, partial(calib_loop, runner), save_path)


if __name__ == "__main__":
    main()