## CacheDiT
CacheDiT is a unified, flexible, and training-free cache acceleration framework designed to support nearly all of Diffusers' DiT-based pipelines. It provides a unified cache API that supports an automatic block adapter, DBCache, and more.
To learn more, refer to the [CacheDiT](https://github.com/vipshop/cache-dit) repository.
Install a stable release of CacheDiT from PyPI, or install the latest version from GitHub.
<hfoptions id="install">
<hfoption id="PyPI">
```bash
pip3 install -U cache-dit
```
</hfoption>
<hfoption id="source">
```bash
pip3 install git+https://github.com/vipshop/cache-dit.git
```
</hfoption>
</hfoptions>
Run the code below to view the supported DiT pipelines.
```python
>>> import cache_dit
>>> cache_dit.supported_pipelines()
(30, ['Flux*', 'Mochi*', 'CogVideoX*', 'Wan*', 'HunyuanVideo*', 'QwenImage*', 'LTX*', 'Allegro*',
'CogView3Plus*', 'CogView4*', 'Cosmos*', 'EasyAnimate*', 'SkyReelsV2*', 'StableDiffusion3*',
'ConsisID*', 'DiT*', 'Amused*', 'Bria*', 'Lumina*', 'OmniGen*', 'PixArt*', 'Sana*', 'StableAudio*',
'VisualCloze*', 'AuraFlow*', 'Chroma*', 'ShapE*', 'HiDream*', 'HunyuanDiT*', 'HunyuanDiTPAG*'])
```
For a complete benchmark, please refer to [Benchmarks](https://github.com/vipshop/cache-dit/blob/main/bench/).
## Unified Cache API
CacheDiT works by matching specific input/output patterns in a pipeline's transformer blocks.
Call the `enable_cache()` function on a pipeline to enable cache acceleration. This function is the entry point to many of CacheDiT's features.
```python
import cache_dit
from diffusers import DiffusionPipeline

# Can be any diffusion pipeline.
pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")

# One line of code with default cache options.
cache_dit.enable_cache(pipe)

# Call the pipe as normal.
output = pipe(...)

# Disable caching and run the original pipe.
cache_dit.disable_cache(pipe)
```
## Automatic Block Adapter
For custom or modified pipelines or transformers not included in Diffusers, use the `BlockAdapter` in `auto` mode or via manual configuration. Please check the [BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#automatic-block-adapter) docs for more details. Refer to [Qwen-Image w/ BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_qwen_image_adapter.py) as an example.
```python
import cache_dit
from cache_dit import ForwardPattern, BlockAdapter

# Use 🔥BlockAdapter with `auto` mode.
cache_dit.enable_cache(
    BlockAdapter(
        # Any DiffusionPipeline, Qwen-Image, etc.
        pipe=pipe, auto=True,
        # Check the `📚Forward Pattern Matching` documentation and inspect the code
        # of Qwen-Image; you will find that it satisfies `FORWARD_PATTERN_1`.
        forward_pattern=ForwardPattern.Pattern_1,
    ),
)

# Or, manually set up the transformer configuration.
cache_dit.enable_cache(
    BlockAdapter(
        pipe=pipe,  # Qwen-Image, etc.
        transformer=pipe.transformer,
        blocks=pipe.transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_1,
    ),
)
```
Sometimes, a Transformer class contains more than one set of transformer `blocks`. For example, FLUX.1 (as well as HiDream, Chroma, etc.) contains both `transformer_blocks` and `single_transformer_blocks`, which have different forward patterns. The BlockAdapter is able to detect this hybrid pattern type as well.
Refer to [FLUX.1](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_flux_adapter.py) as an example.
```python
# For diffusers <= 0.34.0, FLUX.1 transformer_blocks and
# single_transformer_blocks have different forward patterns.
cache_dit.enable_cache(
    BlockAdapter(
        pipe=pipe,  # FLUX.1, etc.
        transformer=pipe.transformer,
        blocks=[
            pipe.transformer.transformer_blocks,
            pipe.transformer.single_transformer_blocks,
        ],
        forward_pattern=[
            ForwardPattern.Pattern_1,
            ForwardPattern.Pattern_3,
        ],
    ),
)
```
This also works for pipelines whose structure contains more than one transformer (namely `transformer` and `transformer_2`). Refer to [Wan 2.2 MoE](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_wan_2.2.py) as an example.
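For the multi-transformer case, the adapter configuration can be sketched by passing parallel lists. This is a hypothetical sketch only — the attribute names (`transformer_2`, `blocks`) and the forward patterns vary by model, so check the linked Wan 2.2 example for the exact arguments.

```python
import cache_dit
from cache_dit import ForwardPattern, BlockAdapter

# Hypothetical sketch for a pipeline with two transformers; the actual
# block attribute names and forward patterns depend on the model.
cache_dit.enable_cache(
    BlockAdapter(
        pipe=pipe,
        transformer=[pipe.transformer, pipe.transformer_2],
        blocks=[pipe.transformer.blocks, pipe.transformer_2.blocks],
        forward_pattern=[ForwardPattern.Pattern_2, ForwardPattern.Pattern_2],
    ),
)
```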
## Patch Functor
For any pattern not included in CacheDiT, use a Patch Functor to convert the pattern into a known one. You need to subclass the Patch Functor, and you may also need to fuse the operations of the blocks' for-loop into the block `forward`. After implementing a Patch Functor, set the `patch_functor` property of `BlockAdapter`.

Some Patch Functors are already provided in CacheDiT, [HiDreamPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_hidream.py), [ChromaPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_chroma.py), etc.
```python
@BlockAdapterRegistry.register("HiDream")
def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
    from diffusers import HiDreamImageTransformer2DModel
    from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor

    assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=[
            pipe.transformer.double_stream_blocks,
            pipe.transformer.single_stream_blocks,
        ],
        forward_pattern=[
            ForwardPattern.Pattern_0,
            ForwardPattern.Pattern_3,
        ],
        # NOTE: Set up your custom patch functor here.
        patch_functor=HiDreamPatchFunctor(),
        **kwargs,
    )
```
Finally, you can call the `cache_dit.summary()` function on a pipeline after it has completed inference to get the cache acceleration details.
```python
stats = cache_dit.summary(pipe)
```
```
⚡️Cache Steps and Residual Diffs Statistics: QwenImagePipeline

| Cache Steps | Diffs Min | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Max |
|-------------|-----------|-----------|-----------|-----------|-----------|-----------|
| 23          | 0.045     | 0.084     | 0.114     | 0.147     | 0.241     | 0.297     |
```
## DBCache: Dual Block Cache

DBCache (Dual Block Caching) supports different configurations of compute blocks (F8B12, etc.) to enable a balanced trade-off between performance and precision.
- `Fn_compute_blocks`: DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
- `Bn_compute_blocks`: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use the residual cache.
```python
import torch
import cache_dit
from diffusers import FluxPipeline

pipe_or_adapter = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Default options, F8B0, 8 warmup steps, and unlimited cached
# steps for a good balance between performance and precision.
cache_dit.enable_cache(pipe_or_adapter)

# Custom options, F8B8, higher precision.
from cache_dit import BasicCacheConfig

cache_dit.enable_cache(
    pipe_or_adapter,
    cache_config=BasicCacheConfig(
        max_warmup_steps=8,   # steps that are never cached
        max_cached_steps=-1,  # -1 means no limit
        Fn_compute_blocks=8,  # Fn, F8, etc.
        Bn_compute_blocks=8,  # Bn, B8, etc.
        residual_diff_threshold=0.12,
    ),
)
```
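The decision behind `residual_diff_threshold` can be illustrated with a small, self-contained sketch in plain Python. This is not CacheDiT's actual implementation; it only shows the idea of comparing consecutive residuals with a relative L1 metric and reusing the cache when the change is small.

```python
def should_reuse_cache(prev_residual, curr_residual, threshold=0.12):
    """Return True when the relative L1 difference between consecutive
    residuals falls below the threshold, i.e. the step is redundant
    enough to reuse the cached residual instead of recomputing blocks.

    Illustrative sketch only, not CacheDiT's actual implementation.
    """
    l1_diff = sum(abs(a - b) for a, b in zip(prev_residual, curr_residual))
    l1_norm = sum(abs(a) for a in prev_residual)
    if l1_norm == 0:
        return False  # avoid dividing by zero; always recompute
    return (l1_diff / l1_norm) < threshold

# Nearly identical residuals -> reuse the cache.
print(should_reuse_cache([1.0, 2.0, 3.0], [1.01, 2.02, 3.01]))  # True
# Clearly different residuals -> recompute the blocks.
print(should_reuse_cache([1.0, 2.0, 3.0], [2.0, 1.0, 0.5]))     # False
```

Raising the threshold makes caching more aggressive (faster, less precise); lowering it makes caching more conservative.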
Check the [DBCache](https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md) and [User Guide](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#dbcache) docs for more design details.
## TaylorSeer Calibrator
The [TaylorSeers](https://huggingface.co/papers/2503.06923) algorithm further improves the precision of DBCache when the number of cached steps is large (hybrid TaylorSeer + DBCache). At timesteps with significant intervals, the feature similarity in diffusion models decreases substantially, significantly harming generation quality.
TaylorSeer employs a differential method to approximate the higher-order derivatives of features and predict features in future timesteps with Taylor series expansion. The TaylorSeer implemented in CacheDiT supports both hidden states and residual cache types. F_pred can be a residual cache or a hidden-state cache.
```python
from cache_dit import BasicCacheConfig, TaylorSeerCalibratorConfig

cache_dit.enable_cache(
    pipe_or_adapter,
    # Basic DBCache w/ FnBn configurations.
    cache_config=BasicCacheConfig(
        max_warmup_steps=8,   # steps that are never cached
        max_cached_steps=-1,  # -1 means no limit
        Fn_compute_blocks=8,  # Fn, F8, etc.
        Bn_compute_blocks=8,  # Bn, B8, etc.
        residual_diff_threshold=0.12,
    ),
    # Then, use the TaylorSeer calibrator to approximate the values
    # at cached steps; taylorseer_order defaults to 1.
    calibrator_config=TaylorSeerCalibratorConfig(
        taylorseer_order=1,
    ),
)
```
> [!TIP]
> The `Bn_compute_blocks` parameter of DBCache can be set to `0` if you use TaylorSeer as the calibrator for approximate hidden states. Because DBCache's `Bn_compute_blocks` also acts as a calibrator, you can choose either `Bn_compute_blocks > 0` or TaylorSeer. We recommend the TaylorSeer + DBCache FnB0 configuration.
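The idea behind TaylorSeer's prediction can be sketched in a few lines of plain Python. This is a first-order illustration only, not CacheDiT's implementation: approximate the feature derivative with a finite difference between the last two computed features, then extrapolate the feature at a future timestep via a Taylor expansion.

```python
def taylor_predict(f_prev, f_curr, t_prev, t_curr, t_future):
    """First-order Taylor prediction of a feature at t_future.

    Approximates dF/dt with a finite difference between the last two
    computed features, then extrapolates linearly. Illustrative sketch
    only; CacheDiT supports higher orders via `taylorseer_order`.
    """
    dt = t_curr - t_prev
    derivative = [(c - p) / dt for c, p in zip(f_curr, f_prev)]
    step = t_future - t_curr
    return [c + d * step for c, d in zip(f_curr, derivative)]

# Features computed at t=0 and t=1; predict t=2 by linear extrapolation.
print(taylor_predict([0.0, 1.0], [0.5, 1.5], 0, 1, 2))  # [1.0, 2.0]
```

Higher `taylorseer_order` values use higher-order finite differences for the same purpose, at the cost of keeping more history.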
## Hybrid Cache CFG
CacheDiT supports caching for CFG (classifier-free guidance). For models that fuse CFG and non-CFG into a single forward step, or that do not perform CFG in the forward step, leave the `enable_separate_cfg` parameter as `False` (the default is `None`). Otherwise, set it to `True`.
```python
from cache_dit import BasicCacheConfig

cache_dit.enable_cache(
    pipe_or_adapter,
    cache_config=BasicCacheConfig(
        ...,
        # For example, set it to True for Wan 2.1 and Qwen-Image,
        # and to False for FLUX.1, HunyuanVideo, etc.
        enable_separate_cfg=True,
    ),
)
```
## torch.compile
CacheDiT is designed to work with torch.compile for even better performance. Call `torch.compile` after enabling the cache.
```python
import torch

cache_dit.enable_cache(pipe)

# Compile the Transformer module after enabling the cache.
pipe.transformer = torch.compile(pipe.transformer)
```
If you're using CacheDiT with dynamic input shapes, consider increasing the `recompile_limit` of `torch._dynamo`. Otherwise, the `recompile_limit` error may be triggered, causing the module to fall back to eager mode.
```python
torch._dynamo.config.recompile_limit = 96 # default is 8
torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
```
Please check [perf.py](https://github.com/vipshop/cache-dit/blob/main/bench/perf.py) for more details.