| | import fnmatch |
| | from contextlib import contextmanager |
| |
|
| | from diffusers.models.attention import BasicTransformerBlock, JointTransformerBlock |
| | from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel |
| | from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel |
| | from diffusers.models.unets.unet_2d_blocks import ( |
| | CrossAttnDownBlock2D, |
| | CrossAttnUpBlock2D, |
| | DownBlock2D, |
| | UNetMidBlock2DCrossAttn, |
| | UpBlock2D, |
| | ) |
| | from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel |
| | from diffusers.models.unets.unet_3d_blocks import ( |
| | CrossAttnDownBlockSpatioTemporal, |
| | CrossAttnUpBlockSpatioTemporal, |
| | DownBlockSpatioTemporal, |
| | UNetMidBlockSpatioTemporal, |
| | UpBlockSpatioTemporal, |
| | ) |
| | from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel |
| |
|
| | from .module import CachedModule |
| | from .utils import replace_module |
| |
|
# Maps each supported diffusers model class to the block classes eligible for
# ``CachedModule`` wrapping. ``prepare`` passes the value for the model's class
# to ``cachify`` as its ``modules`` argument, which ends up as the second
# argument of ``isinstance``. Every value is therefore a tuple of classes; note
# the trailing commas on the single-class entries, without which the
# parentheses would be a no-op and the value a bare class.
CACHED_PIPE = {
    UNet2DConditionModel: (
        DownBlock2D,
        CrossAttnDownBlock2D,
        UNetMidBlock2DCrossAttn,
        CrossAttnUpBlock2D,
        UpBlock2D,
    ),
    PixArtTransformer2DModel: (BasicTransformerBlock,),
    UNetSpatioTemporalConditionModel: (
        CrossAttnDownBlockSpatioTemporal,
        DownBlockSpatioTemporal,
        UpBlockSpatioTemporal,
        CrossAttnUpBlockSpatioTemporal,
        UNetMidBlockSpatioTemporal,
    ),
    SD3Transformer2DModel: (JointTransformerBlock,),
}
| |
|
| |
|
def _apply_to_modules(model, action, modules=None, config_list=None):
    """Visit the cached sub-modules of *model*, optionally wrapping new ones.

    Two traversal modes exist:

    * If ``model.use_trt_infer`` is truthy, the entries of ``model.engines``
      are visited by key.
    * Otherwise every entry of ``model.named_modules()`` is visited by its
      dotted name.

    For each visited entry:

    * If it is already a ``CachedModule``, ``action(entry)`` is called.
    * Otherwise, if a ``config_list`` is supplied (and, in the module mode,
      ``modules`` too), the entry is wrapped as
      ``CachedModule(entry, config["select_cache_step_func"])`` for each
      config whose ``"wildcard_or_filter_func"`` selects the entry's
      name/key.
      - In module mode the entry must additionally be an instance of
        *modules*.
      - Each matching config rewraps the *original* entry, so when several
        configs match, the last one wins.

    Args:
        model: Model to traverse.
        action: Callable invoked with each existing ``CachedModule``.
            Its return value is ignored.
        modules: Class or tuple of classes used as the ``isinstance`` filter
            in module mode. Not consulted in engine mode.
        config_list: Optional sequence of dicts with the keys
            ``"wildcard_or_filter_func"`` and ``"select_cache_step_func"``.
    """
    if getattr(model, "use_trt_infer", False):
        for engine_key, engine in model.engines.items():
            if isinstance(engine, CachedModule):
                action(engine)
                continue
            if not config_list:
                continue
            for cfg in config_list:
                if _pass(engine_key, cfg["wildcard_or_filter_func"]):
                    # Only the current key is reassigned, so iterating the
                    # live items() view stays valid.
                    model.engines[engine_key] = CachedModule(
                        engine, cfg["select_cache_step_func"]
                    )
        return

    for qualified_name, submodule in model.named_modules():
        if isinstance(submodule, CachedModule):
            action(submodule)
            continue
        if not (modules and config_list):
            continue
        for cfg in config_list:
            # The name filter is evaluated before the type check, matching
            # the original short-circuit order.
            if not _pass(qualified_name, cfg["wildcard_or_filter_func"]):
                continue
            if isinstance(submodule, modules):
                replace_module(
                    model,
                    qualified_name,
                    CachedModule(submodule, cfg["select_cache_step_func"]),
                )
| |
|
| |
|
def cachify(model, config_list, modules):
    """Wrap the sub-modules of *model* selected by *config_list* in ``CachedModule``.

    Args:
        model: Model whose sub-modules (or engines) are to be wrapped.
        config_list: Sequence of config dicts consumed by ``_apply_to_modules``.
        modules: Class or tuple of classes a sub-module must be an instance of
            to be wrapped.
    """
    # Entries that are already CachedModule instances need nothing further,
    # so the per-entry action is a deliberate no-op.
    _apply_to_modules(model, lambda _already_cached: None, modules, config_list)
| |
|
| |
|
def disable(pipe):
    """Call ``disable_cache()`` on every ``CachedModule`` in *pipe*'s model.

    The model is resolved with ``get_model`` (``pipe.unet`` or
    ``pipe.transformer``). No new modules are wrapped, because no
    ``config_list`` is passed.
    """
    def _switch_off(cached):
        cached.disable_cache()

    _apply_to_modules(get_model(pipe), _switch_off)
| |
|
| |
|
def enable(pipe):
    """Call ``enable_cache()`` on every ``CachedModule`` in *pipe*'s model.

    The model is resolved with ``get_model`` (``pipe.unet`` or
    ``pipe.transformer``). No new modules are wrapped, because no
    ``config_list`` is passed.
    """
    def _switch_on(cached):
        cached.enable_cache()

    _apply_to_modules(get_model(pipe), _switch_on)
| |
|
| |
|
def reset_status(pipe):
    """Set ``cur_step`` to 0 on every ``CachedModule`` in *pipe*'s model.

    The model is resolved with ``get_model`` (``pipe.unet`` or
    ``pipe.transformer``).
    """
    def _rewind(cached):
        setattr(cached, "cur_step", 0)

    _apply_to_modules(get_model(pipe), _rewind)
| |
|
| |
|
| | def _pass(name, wildcard_or_filter_func): |
| | if isinstance(wildcard_or_filter_func, str): |
| | return fnmatch.fnmatch(name, wildcard_or_filter_func) |
| | elif callable(wildcard_or_filter_func): |
| | return wildcard_or_filter_func(name) |
| | else: |
| | raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}") |
| |
|
| |
|
def get_model(pipe):
    """Return the model held by *pipe*'s ``unet`` or ``transformer`` attribute.

    ``unet`` is checked first. ``transformer`` is used only when *pipe* has no
    ``unet`` attribute.

    Raises:
        KeyError: If *pipe* has neither attribute.
    """
    if hasattr(pipe, "unet"):
        return pipe.unet
    if hasattr(pipe, "transformer"):
        return pipe.transformer
    # Same exception type as before, but with a message naming the pipeline
    # so the failure is diagnosable.
    raise KeyError(
        f"{type(pipe).__name__} has neither a 'unet' nor a 'transformer' attribute"
    )
| |
|
| |
|
@contextmanager
def infer(pipe):
    """Yield *pipe* and call ``reset_status`` on it when the block exits.

    The reset sits in a ``finally`` clause, so it runs whether the ``with``
    body completes normally or raises.
    """
    active_pipe = pipe
    try:
        yield active_pipe
    finally:
        reset_status(active_pipe)
| |
|
| |
|
def prepare(pipe, config_list):
    """Wrap the cacheable blocks of *pipe*'s model in ``CachedModule``.

    Args:
        pipe: Pipeline exposing ``unet`` or ``transformer`` (resolved by
            ``get_model``).
        config_list: Sequence of dicts, each with the following keys:
            - ``"wildcard_or_filter_func"``: a selector string or callable.
            - ``"select_cache_step_func"``: passed to ``CachedModule``.

    Raises:
        AssertionError: If the model's exact class is not a key of
            ``CACHED_PIPE``.
    """
    model = get_model(pipe)
    model_cls = model.__class__
    # Explicit raise instead of ``assert``: an assert is stripped under
    # ``python -O``, after which the lookup below would fail with an opaque
    # KeyError. AssertionError is kept so existing callers see the same type.
    if model_cls not in CACHED_PIPE:
        raise AssertionError(f"{model_cls} is not supported!")
    cachify(model, config_list, CACHED_PIPE[model_cls])
| |
|