|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import types |
|
|
from pathlib import Path |
|
|
|
|
|
import tensorrt as trt |
|
|
import torch |
|
|
from cache_diffusion.cachify import CACHED_PIPE, get_model |
|
|
from cuda import cudart |
|
|
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel |
|
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel |
|
|
from trt_pipeline.config import ONNX_CONFIG |
|
|
from trt_pipeline.models.sd3 import sd3_forward |
|
|
from trt_pipeline.models.sdxl import ( |
|
|
cachecrossattnupblock2d_forward, |
|
|
cacheunet_forward, |
|
|
cacheupblock2d_forward, |
|
|
) |
|
|
from polygraphy.backend.trt import ( |
|
|
CreateConfig, |
|
|
Profile, |
|
|
engine_from_network, |
|
|
network_from_onnx_path, |
|
|
save_engine, |
|
|
) |
|
|
from torch.onnx import export as onnx_export |
|
|
|
|
|
from .utils import Engine |
|
|
|
|
|
|
|
|
def replace_new_forward(backbone): |
|
|
if backbone.__class__ == UNet2DConditionModel: |
|
|
backbone.forward = types.MethodType(cacheunet_forward, backbone) |
|
|
for upsample_block in backbone.up_blocks: |
|
|
if ( |
|
|
hasattr(upsample_block, "has_cross_attention") |
|
|
and upsample_block.has_cross_attention |
|
|
): |
|
|
upsample_block.forward = types.MethodType( |
|
|
cachecrossattnupblock2d_forward, upsample_block |
|
|
) |
|
|
else: |
|
|
upsample_block.forward = types.MethodType(cacheupblock2d_forward, upsample_block) |
|
|
elif backbone.__class__ == SD3Transformer2DModel: |
|
|
backbone.forward = types.MethodType(sd3_forward, backbone) |
|
|
|
|
|
|
|
|
def get_input_info(dummy_dict, info: str = None, batch_size: int = 1): |
|
|
return_val = [] if info == "profile_shapes" or info == "input_names" else {} |
|
|
|
|
|
def collect_leaf_keys(d): |
|
|
for key, value in d.items(): |
|
|
if isinstance(value, dict): |
|
|
collect_leaf_keys(value) |
|
|
else: |
|
|
value = (value[0] * batch_size,) + value[1:] |
|
|
if info == "profile_shapes": |
|
|
return_val.append((key, value)) |
|
|
elif info == "profile_shapes_dict": |
|
|
return_val[key] = value |
|
|
elif info == "dummy_input": |
|
|
return_val[key] = torch.ones(value).half().cuda() |
|
|
elif info == "input_names": |
|
|
return_val.append(key) |
|
|
|
|
|
collect_leaf_keys(dummy_dict) |
|
|
return return_val |
|
|
|
|
|
|
|
|
def get_total_device_memory(backbone): |
|
|
max_device_memory = 0 |
|
|
for _, engine in backbone.engines.items(): |
|
|
max_device_memory = max(max_device_memory, engine.engine.device_memory_size) |
|
|
return max_device_memory |
|
|
|
|
|
|
|
|
def load_engines(backbone, engine_path: Path, batch_size: int = 1): |
|
|
backbone.engines = {} |
|
|
for f in engine_path.iterdir(): |
|
|
if f.is_file(): |
|
|
eng = Engine() |
|
|
eng.load(str(f)) |
|
|
backbone.engines[f"{f.stem}"] = eng |
|
|
_, shared_device_memory = cudart.cudaMalloc(get_total_device_memory(backbone)) |
|
|
for engine in backbone.engines.values(): |
|
|
engine.activate(shared_device_memory) |
|
|
backbone.cuda_stream = cudart.cudaStreamCreate()[1] |
|
|
for block_name in backbone.engines.keys(): |
|
|
backbone.engines[block_name].allocate_buffers( |
|
|
shape_dict=get_input_info( |
|
|
ONNX_CONFIG[backbone.__class__][block_name]["dummy_input"], |
|
|
"profile_shapes_dict", |
|
|
batch_size, |
|
|
), |
|
|
device=backbone.device, |
|
|
batch_size=batch_size, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def warm_up(backbone, batch_size: int = 1): |
|
|
print("Warming-up TensorRT engines...") |
|
|
for name, engine in backbone.engines.items(): |
|
|
dummy_input = get_input_info( |
|
|
ONNX_CONFIG[backbone.__class__][name]["dummy_input"], "dummy_input", batch_size |
|
|
) |
|
|
_ = engine(dummy_input, backbone.cuda_stream) |
|
|
|
|
|
|
|
|
def teardown(pipe): |
|
|
backbone = get_model(pipe) |
|
|
for engine in backbone.engines.values(): |
|
|
del engine |
|
|
|
|
|
cudart.cudaStreamDestroy(backbone.cuda_stream) |
|
|
del backbone.cuda_stream |
|
|
|
|
|
|
|
|
def load_unet_trt(unet, engine_path: Path, batch_size: int = 1): |
|
|
backbone = unet |
|
|
engine_path.mkdir(parents=True, exist_ok=True) |
|
|
replace_new_forward(backbone) |
|
|
load_engines(backbone, engine_path, batch_size) |
|
|
warm_up(backbone, batch_size) |
|
|
backbone.use_trt_infer = True |
|
|
|