| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import types |
| from pathlib import Path |
|
|
| import tensorrt as trt |
| import torch |
| from cache_diffusion.cachify import CACHED_PIPE, get_model |
| from cuda import cudart |
| from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel |
| from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel |
| from trt_pipeline.config import ONNX_CONFIG |
| from trt_pipeline.models.sd3 import sd3_forward |
| from trt_pipeline.models.sdxl import ( |
| cachecrossattnupblock2d_forward, |
| cacheunet_forward, |
| cacheupblock2d_forward, |
| ) |
| from polygraphy.backend.trt import ( |
| CreateConfig, |
| Profile, |
| engine_from_network, |
| network_from_onnx_path, |
| save_engine, |
| ) |
| from torch.onnx import export as onnx_export |
|
|
| from .utils import Engine |
|
|
|
|
def replace_new_forward(backbone):
    """Swap cache-aware ``forward`` implementations onto *backbone* in place.

    For a ``UNet2DConditionModel`` the top-level forward and every up-block
    forward are rebound; for an ``SD3Transformer2DModel`` only the top-level
    forward is rebound. Unsupported backbones are left untouched.
    """
    model_cls = backbone.__class__
    if model_cls == UNet2DConditionModel:
        backbone.forward = types.MethodType(cacheunet_forward, backbone)
        for block in backbone.up_blocks:
            # Cross-attention up-blocks get a different cached forward.
            if getattr(block, "has_cross_attention", False):
                new_forward = cachecrossattnupblock2d_forward
            else:
                new_forward = cacheupblock2d_forward
            block.forward = types.MethodType(new_forward, block)
    elif model_cls == SD3Transformer2DModel:
        backbone.forward = types.MethodType(sd3_forward, backbone)
|
|
|
|
def get_input_info(dummy_dict, info: "str | None" = None, batch_size: int = 1):
    """Collect per-tensor info from a (possibly nested) dummy-input spec.

    Args:
        dummy_dict: Nested dict whose leaves are shape tuples; element 0 of
            each shape is the per-sample batch dimension.
        info: What to collect — "profile_shapes" (list of (name, shape)),
            "profile_shapes_dict" (name -> shape), "dummy_input"
            (name -> half-precision CUDA tensor), or "input_names"
            (list of names). Any other value yields an empty dict.
        batch_size: Multiplier applied to each leaf shape's batch dimension.

    Returns:
        A list for "profile_shapes"/"input_names", otherwise a dict.
    """
    # Fix: the original annotated `info` as plain `str` despite defaulting to
    # None; also use a membership test instead of a chained `or`.
    return_val = [] if info in ("profile_shapes", "input_names") else {}

    def collect_leaf_keys(d):
        # Depth-first walk: dict values recurse, shape-tuple leaves are scaled.
        for key, value in d.items():
            if isinstance(value, dict):
                collect_leaf_keys(value)
            else:
                value = (value[0] * batch_size,) + value[1:]
                if info == "profile_shapes":
                    return_val.append((key, value))
                elif info == "profile_shapes_dict":
                    return_val[key] = value
                elif info == "dummy_input":
                    # Engines run fp16 on GPU, so dummies match that layout.
                    return_val[key] = torch.ones(value).half().cuda()
                elif info == "input_names":
                    return_val.append(key)

    collect_leaf_keys(dummy_dict)
    return return_val
|
|
|
|
def get_total_device_memory(backbone):
    """Return the largest ``device_memory_size`` among the backbone's engines.

    Used to size one shared device allocation that every engine can run in.
    Returns 0 when no engines are loaded.
    """
    sizes = (eng.engine.device_memory_size for eng in backbone.engines.values())
    return max(sizes, default=0)
|
|
|
|
def load_engines(backbone, engine_path: Path, batch_size: int = 1):
    """Load every serialized TRT engine file under *engine_path* onto *backbone*.

    Each file's stem becomes the engine's key in ``backbone.engines``. All
    engines are activated against one shared device allocation sized to the
    largest per-engine requirement, a CUDA stream is created for inference,
    and I/O buffers are allocated from the per-block dummy-input shapes in
    ``ONNX_CONFIG``.

    NOTE(review): return codes from cudaMalloc/cudaStreamCreate are discarded;
    failures would surface later as invalid handles — consider checking them.
    """
    backbone.engines = {}
    for f in engine_path.iterdir():
        if f.is_file():
            eng = Engine()
            eng.load(str(f))
            # Key engines by file stem; assumed to match ONNX_CONFIG block names.
            backbone.engines[f"{f.stem}"] = eng
    # One shared allocation big enough for whichever engine needs the most.
    _, shared_device_memory = cudart.cudaMalloc(get_total_device_memory(backbone))
    for engine in backbone.engines.values():
        engine.activate(shared_device_memory)
    # cudart returns (err, handle); keep only the stream handle.
    backbone.cuda_stream = cudart.cudaStreamCreate()[1]
    for block_name in backbone.engines.keys():
        backbone.engines[block_name].allocate_buffers(
            shape_dict=get_input_info(
                ONNX_CONFIG[backbone.__class__][block_name]["dummy_input"],
                "profile_shapes_dict",
                batch_size,
            ),
            device=backbone.device,
            batch_size=batch_size,
        )
| |
|
|
|
|
def warm_up(backbone, batch_size: int = 1):
    """Prime every loaded TRT engine with one dummy inference pass."""
    print("Warming-up TensorRT engines...")
    for block_name, engine in backbone.engines.items():
        sample = get_input_info(
            ONNX_CONFIG[backbone.__class__][block_name]["dummy_input"],
            "dummy_input",
            batch_size,
        )
        # Result is discarded — this pass only triggers lazy initialization.
        engine(sample, backbone.cuda_stream)
|
|
|
|
def teardown(pipe):
    """Release the TensorRT engines and CUDA stream held by *pipe*'s backbone.

    Fix: the previous code did ``for engine in ...values(): del engine``,
    which only unbinds the loop variable — the dict kept every Engine alive,
    so nothing was actually released. Clearing the dict drops the references
    so the Engine objects can be finalized.
    """
    backbone = get_model(pipe)
    backbone.engines.clear()

    cudart.cudaStreamDestroy(backbone.cuda_stream)
    del backbone.cuda_stream
|
|
|
|
def load_unet_trt(unet, engine_path: Path, batch_size: int = 1):
    """Set *unet* up for TensorRT inference.

    Ensures the engine directory exists, installs the cache-aware forwards,
    loads and warms up the serialized engines, then flags the model so the
    pipeline dispatches to TRT at inference time.
    """
    engine_path.mkdir(parents=True, exist_ok=True)
    replace_new_forward(unet)
    load_engines(unet, engine_path, batch_size)
    warm_up(unet, batch_size)
    unet.use_trt_infer = True
|
|