|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
import copy |
|
|
import itertools |
|
|
import json |
|
|
import os |
|
|
import os.path as osp |
|
|
import time |
|
|
from functools import partial |
|
|
from pathlib import Path |
|
|
from typing import Optional, Tuple |
|
|
|
|
|
import cv2 |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
import torch.nn as nn |
|
|
import torch.nn.utils.parametrize as parametrize |
|
|
import torch.utils.benchmark as benchmark |
|
|
from mmengine.utils import mkdir_or_exist |
|
|
|
|
|
|
|
|
from mmpretrain import FeatureExtractor |
|
|
|
|
|
from torch._dynamo import is_compiling as dynamo_is_compiling |
|
|
from torch._higher_order_ops.out_dtype import out_dtype |
|
|
from torch.profiler import ProfilerActivity |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
def _benchmark(model, input, model_name=""): |
|
|
imgs = ( |
|
|
input["imgs"][0, ...].unsqueeze(0) |
|
|
if model_name.lower() == "original" |
|
|
else input["imgs"] |
|
|
) |
|
|
if torch.cuda.is_available(): |
|
|
start_event = torch.cuda.Event(enable_timing=True) |
|
|
end_event = torch.cuda.Event(enable_timing=True) |
|
|
|
|
|
time_ = [] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s = torch.cuda.Stream() |
|
|
s.wait_stream(torch.cuda.current_stream()) |
|
|
with torch.cuda.stream(s), torch.no_grad(): |
|
|
for _ in range(10): |
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.synchronize() |
|
|
start_event.record() |
|
|
model(imgs) |
|
|
end_event.record() |
|
|
if torch.cuda.is_available(): |
|
|
end_event.record() |
|
|
torch.cuda.synchronize() |
|
|
time_.append(start_event.elapsed_time(end_event)) |
|
|
torch.cuda.current_stream().wait_stream(s) |
|
|
mean_time = np.mean(time_[1:]) / (imgs.shape[0]) |
|
|
print(f"For {model_name} model, ", flush=True) |
|
|
print(f"avg time is {mean_time} ms", flush=True) |
|
|
print(f"Total time is {sum(time_)} ms", flush=True) |
|
|
print(f"Each trial time: {time_}", flush=True) |
|
|
return mean_time |
|
|
|
|
|
|
|
|
def _convert_batchnorm(module): |
|
|
module_output = module |
|
|
if isinstance(module, (torch.nn.SyncBatchNorm)): |
|
|
module_output = torch.nn.BatchNorm2d( |
|
|
module.num_features, |
|
|
module.eps, |
|
|
module.momentum, |
|
|
module.affine, |
|
|
module.track_running_stats, |
|
|
) |
|
|
if module.affine: |
|
|
module_output.weight.data = module.weight.data.clone().detach() |
|
|
module_output.bias.data = module.bias.data.clone().detach() |
|
|
|
|
|
module_output.weight.requires_grad = module.weight.requires_grad |
|
|
module_output.bias.requires_grad = module.bias.requires_grad |
|
|
module_output.running_mean = module.running_mean |
|
|
module_output.running_var = module.running_var |
|
|
module_output.num_batches_tracked = module.num_batches_tracked |
|
|
if isinstance(module, (torch.nn.SiLU)): |
|
|
module_output = torch.nn.ReLU(inplace=True) |
|
|
for name, child in module.named_children(): |
|
|
module_output.add_module(name, _convert_batchnorm(child)) |
|
|
del module |
|
|
return module_output |
|
|
|
|
|
|
|
|
def _demo_mm_inputs(input_shape): |
|
|
"""Create a superset of inputs needed to run test or train batches. |
|
|
|
|
|
Args: |
|
|
input_shape (tuple): |
|
|
input batch dimensions |
|
|
num_classes (int): |
|
|
number of semantic classes |
|
|
""" |
|
|
(N, C, H, W) = input_shape |
|
|
rng = np.random.RandomState(0) |
|
|
imgs = rng.rand(*input_shape) |
|
|
|
|
|
mm_inputs = { |
|
|
"imgs": torch.FloatTensor(imgs), |
|
|
} |
|
|
return mm_inputs |
|
|
|
|
|
|
|
|
def explain_model(model, inputs): |
|
|
imgs = inputs["imgs"] |
|
|
with torch.no_grad(): |
|
|
explanation = torch._dynamo.explain(model, imgs) |
|
|
return explanation.graphs, explanation.graph_count, explanation.break_reasons |
|
|
|
|
|
|
|
|
def compile_model(
    model,
    inputs,
    output_file="compiled_model.pt",
    max_batch_size=64,
    dtype=torch.bfloat16,
):
    """Export ``model`` with a dynamic batch axis, benchmark compile modes, save it.

    The model is exported once with ``torch.export.export``. The exported
    module is then compiled with each ``torch.compile`` mode ("default",
    "reduce-overhead", "max-autotune"), warmed up and benchmarked. The
    fastest mode is printed and the exported program is written to
    ``output_file``. Requires CUDA.

    Args:
        model: Module to export. The dynamic-shape spec is keyed by
            ``"inputs"``, so the model's forward is assumed to name its
            first parameter ``inputs`` (TODO confirm against the model).
        inputs: Mapping holding the example ``"imgs"`` tensor. The mapping
            is not modified; a converted copy of the tensor is used.
        output_file: Destination path for ``torch.export.save``.
        max_batch_size: Upper bound of the dynamic batch dimension.
        dtype: Dtype the example tensor is cast to before export.
    """
    modes = {"Default": "default", "RO": "reduce-overhead", "MA": "max-autotune"}

    imgs = inputs["imgs"].to(dtype).cuda()
    # Benchmark with a shallow copy so the caller's dict is left untouched.
    bench_inputs = dict(inputs)
    bench_inputs["imgs"] = imgs

    dynamic_batch = torch.export.Dim("batch", min=1, max=max_batch_size)
    dynamic_shapes = {"inputs": {0: dynamic_batch}}

    with torch.no_grad():
        exported_model = torch.export.export(
            model, args=(imgs,), kwargs={}, dynamic_shapes=dynamic_shapes
        )

    min_mean = float("inf")
    best_mode = None
    for mode_str, mode in modes.items():
        print(f"Compiling model with {mode_str} mode")
        compiled_model = torch.compile(exported_model.module(), mode=mode)

        # Warm up on a side stream; the first calls trigger compilation.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s), torch.no_grad():
            for _ in range(3):
                compiled_model(imgs)
                torch.cuda.synchronize()
        torch.cuda.current_stream().wait_stream(s)

        mean = _benchmark(compiled_model, bench_inputs, model_name=mode_str)
        if mean < min_mean:
            min_mean = mean
            best_mode = mode_str

    print(f"Best compilation mode: {best_mode}")
    torch.export.save(exported_model, output_file)
    print(output_file)
|
|
|
|
|
|
|
|
def parse_args():
    """Parse the command-line arguments of the script.

    Returns:
        argparse.Namespace: with ``config``, ``checkpoint``, ``shape``,
        ``output_dir``, ``max_batch_size``, ``explain_verbose``,
        ``force_compile`` and ``fp16``.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark, export and compile an MMPretrain feature extractor"
    )
    parser.add_argument("config", help="test config file path")
    parser.add_argument("checkpoint", help="checkpoint file")
    parser.add_argument(
        "--shape",
        type=int,
        nargs="+",
        default=[1024, 1024],
        help="input image size (height, width)",
    )
    parser.add_argument(
        "--output_dir",
        "--output-dir",
        type=str,
        # A default keeps os.makedirs() in main() from receiving None.
        default=".",
        help="directory where the exported model is written "
        "(default: current directory)",
    )
    parser.add_argument(
        "--max-batch-size",
        type=int,
        default=64,
        help="Maximum batch size for dynamic compile",
    )
    parser.add_argument(
        "--explain-verbose", action="store_true", help="Explains the model compilation"
    )
    parser.add_argument(
        "--force-compile",
        action="store_true",
        help="Force compile the model even if more than one cuda graphs are present",
    )
    parser.add_argument(
        "--fp16", action="store_true", help="To enable fp16. Default is bf16"
    )
    return parser.parse_args()
|
|
|
|
|
|
|
|
def collate_wrapper(calib_dataloader_collate, *args, **kwargs):
    """Collate with the wrapped function and return its inputs as one float tensor.

    Args:
        calib_dataloader_collate: Collate callable; its result must expose an
            ``"inputs"`` entry holding a sequence of equally shaped tensors.
        *args: Positional arguments forwarded to the collate callable.
        **kwargs: Keyword arguments forwarded to the collate callable.

    Returns:
        The ``"inputs"`` tensors stacked along a new leading dimension and
        cast to ``torch.float``.
    """
    collated = calib_dataloader_collate(*args, **kwargs)
    stacked = torch.stack(collated["inputs"], dim=0)
    return stacked.to(dtype=torch.float)
|
|
|
|
|
|
|
|
def main():
    """Benchmark the original model, check graph capture, then export it.

    Steps: parse the CLI arguments, build a random input batch, load the
    model through ``FeatureExtractor``, benchmark it, run ``explain_model``
    and, unless more than one graph was found without ``--force-compile``,
    cast the model to bf16/fp16 and hand it to ``compile_model``.

    Raises:
        ValueError: If ``--shape`` has neither one nor two values.
    """
    args = parse_args()

    # One value means a square image; two values are used as given.
    spatial = tuple(args.shape)
    if len(spatial) == 1:
        spatial = spatial * 2
    elif len(spatial) != 2:
        raise ValueError("invalid input shape")

    os.makedirs(args.output_dir, exist_ok=True)
    checkpoint_basename = Path(args.checkpoint).stem

    extractor = FeatureExtractor(model=args.config, pretrained=args.checkpoint)
    model = extractor.model
    # NOTE(review): presumably makes the backbone return feature maps;
    # confirm against the mmpretrain backbone in use.
    model.backbone.out_type = "featmap"
    model.eval()

    # Start from a batch of 64 and clamp it to [1, max_batch_size].
    max_batch_size = args.max_batch_size
    batch_size = max(1, min(64, max_batch_size))
    mm_inputs = _demo_mm_inputs((batch_size, 3, *spatial))

    if torch.cuda.is_available():
        model.cuda()
        mm_inputs["imgs"] = mm_inputs["imgs"].cuda()

    _benchmark(model, mm_inputs, "Original")

    graphs, graph_counts, break_reasons = explain_model(model, mm_inputs)
    if args.explain_verbose:
        print(f"Graphs: {graphs}")
        print(f"Graph Counts: {graph_counts}")
        print(f"Reasons: {break_reasons}")

    if graph_counts > 1 and not args.force_compile:
        print(f"Graphs are not fusable. Expected 1 graph. Found {graph_counts}")
        return

    dtype = torch.half if args.fp16 else torch.bfloat16
    model.to(dtype)
    mm_inputs["imgs"] = mm_inputs["imgs"].to(dtype)

    dtype_tag = "float16" if dtype == torch.float16 else "bfloat16"
    save_path = os.path.join(
        args.output_dir, f"{checkpoint_basename}_{dtype_tag}.pt2"
    )
    compile_model(
        model, mm_inputs, save_path, max_batch_size=max_batch_size, dtype=dtype
    )
|
|
|
|
|
|
|
|
# Run the benchmark/export pipeline only when the file is executed as a script.
if __name__ == "__main__":
    main()
|
|
|