| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import types |
| from pathlib import Path |
|
|
| import tensorrt as trt |
| import torch |
| from cache_diffusion.cachify import CACHED_PIPE, get_model |
| from cuda import cudart |
| from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel |
| from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel |
| from pipe.config import ONNX_CONFIG |
| from pipe.models.sd3 import sd3_forward |
| from pipe.models.sdxl import ( |
| cachecrossattnupblock2d_forward, |
| cacheunet_forward, |
| cacheupblock2d_forward, |
| ) |
| from polygraphy.backend.trt import ( |
| CreateConfig, |
| Profile, |
| engine_from_network, |
| network_from_onnx_path, |
| save_engine, |
| ) |
| from torch.onnx import export as onnx_export |
|
|
| from .utils import Engine |
|
|
|
|
def replace_new_forward(backbone):
    """Rebind ``forward`` on a supported backbone to its cache-aware version.

    The dispatch is an exact class comparison, so subclasses of the two
    supported models are left untouched:

    * ``UNet2DConditionModel``: the backbone gets ``cacheunet_forward`` and
      every entry of ``backbone.up_blocks`` gets either
      ``cachecrossattnupblock2d_forward`` (when the block has a truthy
      ``has_cross_attention`` attribute) or ``cacheupblock2d_forward``.
    * ``SD3Transformer2DModel``: the backbone gets ``sd3_forward``.

    Any other class is returned unchanged. The function mutates ``backbone``
    in place and returns ``None``.
    """
    model_type = backbone.__class__

    if model_type == UNet2DConditionModel:
        backbone.forward = types.MethodType(cacheunet_forward, backbone)
        for block in backbone.up_blocks:
            # A missing attribute is treated the same as a falsy one.
            if getattr(block, "has_cross_attention", False):
                cached_forward = cachecrossattnupblock2d_forward
            else:
                cached_forward = cacheupblock2d_forward
            block.forward = types.MethodType(cached_forward, block)
        return

    if model_type == SD3Transformer2DModel:
        backbone.forward = types.MethodType(sd3_forward, backbone)
|
|
|
|
def get_input_info(dummy_dict, info: str = None, batch_size: int = 1):
    """Flatten a (possibly nested) shape config into the requested view.

    ``dummy_dict`` maps names to shape tuples; nested dicts are walked
    recursively and only their leaves are reported. The first dimension of
    every leaf shape is multiplied by ``batch_size``.

    ``info`` selects the result:

    * ``"profile_shapes"``      -> list of ``(name, shape)`` pairs
    * ``"input_names"``         -> list of names
    * ``"profile_shapes_dict"`` -> dict ``name -> shape``
    * ``"dummy_input"``         -> dict ``name -> all-ones half tensor on CUDA``
    * anything else (including the default ``None``) -> empty dict
    """

    def iter_leaves(tree):
        # Depth-first, in dict order: yields (name, batch-scaled shape).
        for name, entry in tree.items():
            if isinstance(entry, dict):
                yield from iter_leaves(entry)
            else:
                yield name, (entry[0] * batch_size,) + entry[1:]

    # Walk the whole tree up front so every mode (even an unknown one)
    # touches each leaf exactly once.
    leaves = list(iter_leaves(dummy_dict))

    if info == "profile_shapes":
        return leaves
    if info == "input_names":
        return [name for name, _ in leaves]
    if info == "profile_shapes_dict":
        return dict(leaves)
    if info == "dummy_input":
        return {name: torch.ones(shape).half().cuda() for name, shape in leaves}
    return {}
|
|
|
|
def complie2trt(cls, onnx_path: Path, engine_path: Path, batch_size: int = 1):
    """Build one TensorRT plan per exported ONNX block that is still missing.

    Each sub-directory of ``onnx_path`` whose name is a key of
    ``ONNX_CONFIG[cls]`` is expected to hold a ``model.onnx``. Its engine is
    written to ``engine_path / "<name>.plan"`` unless that file already
    exists, in which case the block is skipped with a message.

    The optimization profile of every input uses the batch-scaled shape from
    the config as both the opt and max shape; the min shape is the same
    shape with its first dimension set to 2.
    """
    block_dirs = list(filter(Path.is_dir, onnx_path.iterdir()))

    for block_dir in block_dirs:
        block_name = block_dir.name
        if block_name not in ONNX_CONFIG[cls]:
            continue

        onnx_file = block_dir / "model.onnx"
        plan_file = engine_path / f"{block_name}.plan"
        if plan_file.exists():
            print(f"{onnx_file} already exists!")
            continue

        print(f"Building {onnx_file}")
        shapes = get_input_info(
            ONNX_CONFIG[cls][block_name]["dummy_input"], "profile_shapes", batch_size
        )
        profile = Profile()
        for tensor_name, full_shape in shapes:
            # Arguments are (name, min, opt, max).
            profile.add(tensor_name, (2,) + full_shape[1:], full_shape, full_shape)

        network = network_from_onnx_path(
            str(onnx_file),
            flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM],
            strongly_typed=True,
        )
        config = CreateConfig(
            builder_optimization_level=6,
            tf32=True,
            profiles=[profile],
        )
        save_engine(engine_from_network(network, config=config), path=plan_file)
|
|
|
|
def get_total_device_memory(backbone):
    """Return the largest ``device_memory_size`` among the backbone's engines.

    Despite the name, this is a maximum rather than a sum. The result is
    never below 0, and is 0 when ``backbone.engines`` is empty.
    """
    sizes = [wrapper.engine.device_memory_size for wrapper in backbone.engines.values()]
    return max([0, *sizes])
|
|
|
|
def load_engines(backbone, engine_path: Path, batch_size: int = 1):
    """Load every engine file in ``engine_path`` and prepare it for inference.

    Steps, in order:

    1. Each regular file in ``engine_path`` is loaded through ``Engine`` and
       stored in ``backbone.engines`` under the file's stem.
    2. One device buffer, sized by ``get_total_device_memory``, is allocated
       and handed to every engine's ``activate``.
    3. A CUDA stream is created and stored as ``backbone.cuda_stream``.
    4. Each engine allocates its I/O buffers from the batch-scaled shapes in
       ``ONNX_CONFIG[backbone.__class__][<stem>]["dummy_input"]``.

    The shared buffer pointer is kept as ``backbone.shared_device_memory`` so
    that whoever owns the backbone can release it with ``cudart.cudaFree``;
    previously the pointer was dropped and could never be freed.

    Raises:
        RuntimeError: if ``cudaMalloc`` or ``cudaStreamCreate`` reports an
            error status (these codes used to be silently ignored).
        KeyError: if a file stem has no entry in ``ONNX_CONFIG``.
    """
    backbone.engines = {}
    for plan_file in engine_path.iterdir():
        if plan_file.is_file():
            eng = Engine()
            eng.load(str(plan_file))
            backbone.engines[plan_file.stem] = eng

    # cudart calls return (status, value); surface failures instead of
    # passing a bogus pointer on to the engines.
    status, shared_device_memory = cudart.cudaMalloc(get_total_device_memory(backbone))
    if status != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(f"cudaMalloc for shared engine memory failed: {status}")
    backbone.shared_device_memory = shared_device_memory

    for engine in backbone.engines.values():
        engine.activate(shared_device_memory)

    status, stream = cudart.cudaStreamCreate()
    if status != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(f"cudaStreamCreate failed: {status}")
    backbone.cuda_stream = stream

    block_configs = ONNX_CONFIG[backbone.__class__]
    for block_name, engine in backbone.engines.items():
        engine.allocate_buffers(
            shape_dict=get_input_info(
                block_configs[block_name]["dummy_input"],
                "profile_shapes_dict",
                batch_size,
            ),
            device=backbone.device,
            batch_size=batch_size,
        )
| |
|
|
|
|
def export_onnx(backbone, onnx_path: Path):
    """Export each cacheable sub-module of ``backbone`` to its own ONNX file.

    A sub-module qualifies when it is an instance of
    ``CACHED_PIPE[backbone.__class__]``. Its model is written to
    ``onnx_path / <module name> / "model.onnx"``; an existing file is left
    alone and only reported. Dummy inputs, input/output names and dynamic
    axes all come from ``ONNX_CONFIG[backbone.__class__][<module name>]``.
    """
    for name, module in backbone.named_modules():
        if not isinstance(module, CACHED_PIPE[backbone.__class__]):
            continue

        block_dir = onnx_path / name
        onnx_file = block_dir / "model.onnx"
        if onnx_file.exists():
            print(f"{onnx_file} alread exists!")
            continue

        block_dir.mkdir(parents=True, exist_ok=True)
        block_cfg = ONNX_CONFIG[backbone.__class__][name]
        example_inputs = get_input_info(block_cfg["dummy_input"], "dummy_input")
        input_names = get_input_info(block_cfg["dummy_input"], "input_names")
        output_names = block_cfg["output_names"]
        onnx_export(
            module,
            args=example_inputs,
            f=onnx_file.as_posix(),
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=block_cfg["dynamic_axes"],
            do_constant_folding=True,
            opset_version=17,
        )
|
|
|
|
def warm_up(backbone, batch_size: int = 1):
    """Run every loaded engine once on dummy inputs.

    Inputs are built with ``get_input_info`` in ``"dummy_input"`` mode from
    ``ONNX_CONFIG[backbone.__class__][<engine name>]``, scaled by
    ``batch_size``. Each engine is called on ``backbone.cuda_stream`` and its
    result is discarded.
    """
    print("Warming-up TensorRT engines...")
    for block_name, run_engine in backbone.engines.items():
        shape_config = ONNX_CONFIG[backbone.__class__][block_name]["dummy_input"]
        feed = get_input_info(shape_config, "dummy_input", batch_size)
        run_engine(feed, backbone.cuda_stream)
|
|
|
|
def teardown(pipe):
    """Release the TensorRT resources attached to the pipeline's backbone.

    The backbone is obtained with ``get_model(pipe)``. Its engine dict is
    emptied, a shared device buffer is freed if the backbone carries a truthy
    ``shared_device_memory`` pointer, and ``backbone.cuda_stream`` is
    destroyed and removed.

    Raises:
        AttributeError: if the backbone has no ``engines`` or ``cuda_stream``
            attribute (for example when called twice).
    """
    backbone = get_model(pipe)

    # `del` on a loop variable only unbinds that local name; the dict would
    # keep every engine alive. Clearing the dict drops the references.
    backbone.engines.clear()

    # Free the shared buffer only after the engines that used it are gone.
    shared_device_memory = getattr(backbone, "shared_device_memory", None)
    if shared_device_memory:
        cudart.cudaFree(shared_device_memory)
        backbone.shared_device_memory = None

    cudart.cudaStreamDestroy(backbone.cuda_stream)
    del backbone.cuda_stream
|
|
|
|
def compile(pipe, onnx_path: Path, engine_path: Path, batch_size: int = 1):
    """Convert the pipeline's backbone to TensorRT and switch it on.

    Creates ``onnx_path`` and ``engine_path`` if needed, then runs in order:
    forward replacement, ONNX export, engine build, engine loading and a
    warm-up pass. Finally sets ``backbone.use_trt_infer = True``.
    """
    backbone = get_model(pipe)
    for out_dir in (onnx_path, engine_path):
        out_dir.mkdir(parents=True, exist_ok=True)

    replace_new_forward(backbone)
    export_onnx(backbone, onnx_path)
    complie2trt(backbone.__class__, onnx_path, engine_path, batch_size=batch_size)
    load_engines(backbone, engine_path, batch_size=batch_size)
    warm_up(backbone, batch_size=batch_size)

    backbone.use_trt_infer = True
|
|