File size: 9,210 Bytes
789eef1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import copy
import itertools
import json
import os
import os.path as osp
import time
from functools import partial
from pathlib import Path
from typing import Optional, Tuple
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
import torch.utils.benchmark as benchmark
from mmengine.utils import mkdir_or_exist
# import torch_tensorrt
from mmpretrain import FeatureExtractor
from torch._dynamo import is_compiling as dynamo_is_compiling
from torch._higher_order_ops.out_dtype import out_dtype
from torch.profiler import ProfilerActivity
from tqdm import tqdm
def _benchmark(model, input, model_name="", num_iters=10):
    """Time repeated forward passes of ``model`` and print latency statistics.

    Args:
        model: Callable invoked as ``model(imgs)``.
        input: Dict whose ``"imgs"`` entry is the batched input tensor.
        model_name: Label used in the printed report. When it equals
            ``"original"`` (case-insensitive) only the first sample of the
            batch is fed to the model.
        num_iters: Number of timed forward passes. When more than one pass is
            run, the first is treated as warm-up and excluded from the mean.

    Returns:
        Mean latency per image, in milliseconds.

    Raises:
        ValueError: If ``num_iters`` is smaller than 1.
    """
    if num_iters < 1:
        raise ValueError(f"num_iters must be >= 1, got {num_iters}")
    imgs = (
        input["imgs"][0, ...].unsqueeze(0)
        if model_name.lower() == "original"
        else input["imgs"]
    )
    timings = []
    if torch.cuda.is_available():
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        # Timed work runs on a dedicated stream that first waits for any work
        # already queued on the current stream.
        side_stream = torch.cuda.Stream()
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream), torch.no_grad():
            for _ in range(num_iters):
                torch.cuda.synchronize()
                start_event.record()
                model(imgs)
                end_event.record()
                # Both events must have completed before elapsed_time is read.
                torch.cuda.synchronize()
                timings.append(start_event.elapsed_time(end_event))
        torch.cuda.current_stream().wait_stream(side_stream)
    else:
        # CPU fallback: wall-clock timing converted to milliseconds, the same
        # unit torch.cuda.Event.elapsed_time reports.
        with torch.no_grad():
            for _ in range(num_iters):
                started = time.perf_counter()
                model(imgs)
                timings.append((time.perf_counter() - started) * 1000.0)
    # Drop the warm-up pass only when there is something left to average.
    measured = timings[1:] if len(timings) > 1 else timings
    mean_time = np.mean(measured) / (imgs.shape[0])
    print(f"For {model_name} model, ", flush=True)
    print(f"avg time is {mean_time} ms", flush=True)
    print(f"Total time is {sum(timings)} ms", flush=True)
    print(f"Each trial time: {timings}", flush=True)
    return mean_time
def _convert_batchnorm(module, replace_silu=True):
    """Recursively replace ``SyncBatchNorm`` layers with ``BatchNorm2d``.

    Affine parameters (with their ``requires_grad`` flags) are cloned into the
    new layer, and the running-statistics buffers are carried over.

    NOTE: unless ``replace_silu`` is False, every ``SiLU`` activation is also
    swapped for ``ReLU(inplace=True)``. That changes the model's outputs; it
    stays the default only so existing callers keep their behavior.

    Args:
        module: Root module to convert. For any module that is not itself
            replaced, its children are swapped in place via ``add_module``.
        replace_silu: Whether ``SiLU`` modules are replaced with in-place ReLU.

    Returns:
        The converted module (``module`` itself when the root is not replaced).
    """
    converted = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        converted = torch.nn.BatchNorm2d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            converted.weight.data = module.weight.data.clone().detach()
            converted.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            converted.weight.requires_grad = module.weight.requires_grad
            converted.bias.requires_grad = module.bias.requires_grad
        # The same buffer tensors are assigned (shared, not copied).
        converted.running_mean = module.running_mean
        converted.running_var = module.running_var
        converted.num_batches_tracked = module.num_batches_tracked
    elif replace_silu and isinstance(module, torch.nn.SiLU):
        converted = torch.nn.ReLU(inplace=True)
    for name, child in module.named_children():
        converted.add_module(
            name, _convert_batchnorm(child, replace_silu=replace_silu)
        )
    return converted
def _demo_mm_inputs(input_shape, seed=0):
    """Create a deterministic random image batch for benchmarking and tracing.

    Args:
        input_shape (tuple): Batch dimensions ``(N, C, H, W)``.
        seed (int): Seed for ``np.random.RandomState``; 0 by default so
            repeated calls produce identical data.

    Returns:
        dict: ``{"imgs": tensor}`` where ``tensor`` is a float32 tensor of
        shape ``input_shape`` with values drawn uniformly from [0, 1).

    Raises:
        ValueError: If ``input_shape`` does not have exactly four entries.
    """
    input_shape = tuple(input_shape)
    if len(input_shape) != 4:
        raise ValueError(f"input_shape must be (N, C, H, W), got {input_shape}")
    rng = np.random.RandomState(seed)
    imgs = rng.rand(*input_shape)
    # rand() yields float64; the models expect float32 input.
    return {"imgs": torch.from_numpy(imgs).to(torch.float32)}
def explain_model(model, inputs):
imgs = inputs["imgs"]
with torch.no_grad():
explanation = torch._dynamo.explain(model, imgs)
return explanation.graphs, explanation.graph_count, explanation.break_reasons
def compile_model(
    model,
    inputs,
    output_file="compiled_model.pt",
    max_batch_size=64,
    dtype=torch.bfloat16,
    modes=None,
    warmup_iters=3,
):
    """Export ``model`` with a dynamic batch dim and benchmark ``torch.compile`` modes.

    The model is exported once with ``torch.export.export``, with the batch
    dimension dynamic in ``[1, max_batch_size]``. For each entry in ``modes``,
    the exported module is compiled, warmed up for ``warmup_iters`` passes and
    then timed with ``_benchmark``. The fastest label is printed. The
    *exported* program, not a compiled artifact, is saved to ``output_file``.

    Args:
        model: Module to export. Only the example batch is converted here; the
            model is presumably already on CUDA with parameters in ``dtype``
            (as ``main`` arranges) — confirm for other callers.
        inputs: Dict with the example batch under ``"imgs"``. It is not
            modified; a converted copy is used internally.
        output_file: Destination path for ``torch.export.save``.
        max_batch_size: Upper bound of the dynamic batch dimension.
        dtype: dtype the example batch is cast to before export.
        modes: Mapping of report label -> ``torch.compile`` mode. Defaults to
            default / reduce-overhead / max-autotune.
        warmup_iters: Untimed forward passes run after compiling each mode.

    Returns:
        str | None: Label of the fastest mode, or None if ``modes`` is empty.
    """
    if modes is None:
        modes = {
            "Default": "default",
            "RO": "reduce-overhead",
            "MA": "max-autotune",
        }
    imgs = inputs["imgs"].to(dtype).cuda()
    # Shallow copy so the caller's dict keeps its original tensor.
    bench_inputs = {**inputs, "imgs": imgs}
    # The key must match the name of model.forward's image argument; "inputs"
    # presumably matches the model built in main — confirm for other models.
    dynamic_batch = torch.export.Dim("batch", min=1, max=max_batch_size)
    dynamic_shapes = {"inputs": {0: dynamic_batch}}
    with torch.no_grad():
        exported_model = torch.export.export(
            model, args=(imgs,), kwargs={}, dynamic_shapes=dynamic_shapes
        )

    min_mean = float("inf")
    best_mode = None
    for mode_str, mode in modes.items():
        print(f"Compiling model with {mode_str} mode")
        compiled_model = torch.compile(exported_model.module(), mode=mode)
        # torch.compile compiles lazily on the first call, so run a few
        # untimed passes before benchmarking.
        warmup_stream = torch.cuda.Stream()
        warmup_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(warmup_stream), torch.no_grad():
            for _ in range(warmup_iters):
                compiled_model(imgs)
                torch.cuda.synchronize()
        torch.cuda.current_stream().wait_stream(warmup_stream)
        mean = _benchmark(compiled_model, bench_inputs, model_name=mode_str)
        if mean < min_mean:
            min_mean = mean
            best_mode = mode_str
    print(f"Best compilation mode: {best_mode}")
    torch.export.save(exported_model, output_file)
    print(output_file)
    return best_mode
def parse_args():
    """Parse the command line for the export / compile-benchmark script.

    Returns:
        argparse.Namespace with ``config``, ``checkpoint``, ``shape``,
        ``output_dir``, ``max_batch_size``, ``explain_verbose``,
        ``force_compile`` and ``fp16``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Export an mmpretrain model with torch.export and benchmark "
            "torch.compile modes"
        )
    )
    parser.add_argument("config", help="model config file path")
    parser.add_argument("checkpoint", help="checkpoint file")
    parser.add_argument(
        "--shape",
        type=int,
        nargs="+",
        default=[1024, 1024],
        help="input image size: one value for a square image, or height width",
    )
    # Required: main() creates this directory and writes the exported model
    # into it, so there is no usable default.
    parser.add_argument(
        "--output_dir",
        "--output-dir",
        type=str,
        required=True,
        help="directory where the exported .pt2 model is saved",
    )
    parser.add_argument(
        "--max-batch-size",
        type=int,
        default=64,
        help="Maximum batch size for dynamic compile",
    )
    parser.add_argument(
        "--explain-verbose",
        action="store_true",
        help="Print the TorchDynamo explain output (graphs, graph count, break reasons)",
    )
    parser.add_argument(
        "--force-compile",
        action="store_true",
        help="Compile the model even if TorchDynamo splits it into more than one graph",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Use float16 instead of the default bfloat16",
    )
    return parser.parse_args()
def collate_wrapper(calib_dataloader_collate, *args, **kwargs):
    """Collate a batch and return its ``"inputs"`` as one float32 tensor.

    Args:
        calib_dataloader_collate: Collate function whose result is indexable by
            ``"inputs"`` and holds a sequence of same-shaped tensors.
        *args, **kwargs: Forwarded unchanged to ``calib_dataloader_collate``.

    Returns:
        torch.Tensor: The ``"inputs"`` tensors stacked along a new leading
        dimension and cast to ``torch.float``.
    """
    collated = calib_dataloader_collate(*args, **kwargs)
    batch = torch.stack(collated["inputs"], dim=0)
    return batch.to(dtype=torch.float)
def main():
    """Benchmark, export and compile-benchmark an mmpretrain backbone.

    Steps: parse arguments, load the model from config + checkpoint, time the
    eager model, run TorchDynamo explain (aborting if the model splits into
    more than one graph, unless ``--force-compile``), cast to bf16/fp16 and
    hand off to ``compile_model``, which saves a ``.pt2`` export to
    ``--output-dir``.

    Raises:
        ValueError: If no output directory was given or ``--shape`` has more
            than two values.
        RuntimeError: If no CUDA device is available.
    """
    args = parse_args()
    if args.output_dir is None:
        raise ValueError("--output-dir must be provided")
    if len(args.shape) == 1:
        height = width = args.shape[0]
    elif len(args.shape) == 2:
        height, width = args.shape
    else:
        raise ValueError("invalid input shape")
    # compile_model moves the inputs to CUDA unconditionally, so fail here
    # rather than after loading the checkpoint.
    if not torch.cuda.is_available():
        raise RuntimeError("A CUDA device is required to compile the model.")

    os.makedirs(args.output_dir, exist_ok=True)
    checkpoint_basename = Path(args.checkpoint).stem
    model = FeatureExtractor(model=args.config, pretrained=args.checkpoint).model
    model.backbone.out_type = "featmap"  ## removes cls_token and returns spatial feature maps.
    model.eval()

    max_batch_size = args.max_batch_size
    # Benchmark with 64 images, clamped to [1, max_batch_size].
    batch_size = max(1, min(64, max_batch_size))
    mm_inputs = _demo_mm_inputs((batch_size, 3, height, width))
    model.cuda()
    mm_inputs["imgs"] = mm_inputs["imgs"].cuda()
    _benchmark(model, mm_inputs, "Original")

    graphs, graph_counts, break_reasons = explain_model(model, mm_inputs)
    if args.explain_verbose:
        print(f"Graphs: {graphs}")
        print(f"Graph Counts: {graph_counts}")
        print(f"Reasons: {break_reasons}")
    if not args.force_compile and graph_counts > 1:
        print(f"Graphs are not fusable. Expected 1 graph. Found {graph_counts}")
        return

    dtype = torch.half if args.fp16 else torch.bfloat16
    model.to(dtype)
    mm_inputs["imgs"] = mm_inputs["imgs"].to(dtype)
    dtype_name = "float16" if dtype == torch.float16 else "bfloat16"
    save_path = os.path.join(
        args.output_dir, f"{checkpoint_basename}_{dtype_name}.pt2"
    )
    compile_model(
        model, mm_inputs, save_path, max_batch_size=max_batch_size, dtype=dtype
    )
# Run the script only when executed directly, not when imported as a module.
if __name__ == "__main__":
    main()
|