"""DeepCache implementation for LightDiffusion-Next.
Based on:
- https://github.com/horseee/DeepCache
- https://gist.github.com/laksjdjf/435c512bc19636e9c9af4ee7bea9eb86
DeepCache accelerates diffusion models by reusing high-level features
while updating low-level features in a cheap way.
"""
import torch
import logging
class ApplyDeepCacheOnModel:
    """Apply DeepCache optimization to a model.

    DeepCache works by caching intermediate features in the U-Net architecture
    and reusing them for certain steps, significantly reducing computation.
    """

    def patch(
        self,
        model,
        object_to_patch="diffusion_model",
        cache_interval=3,
        cache_depth=2,
        start_step=0,
        end_step=1000,
    ):
        """Patch the model with DeepCache optimization.

        Args:
            model: The model to patch (should be a ModelPatcher or tuple containing one)
            object_to_patch: Name of the model object to patch (default: "diffusion_model")
            cache_interval: Interval for cache updates (higher = more speedup, lower quality).
                Values below 1 are clamped to 1 (equivalent to no caching).
            cache_depth: Depth of caching in U-Net blocks (0-12, higher = more aggressive).
                NOTE(review): currently unused — this implementation reuses the *full*
                U-Net output rather than partial block features; kept for API
                compatibility with the original DeepCache interface.
            start_step: Start applying DeepCache at this timestep (0-1000)
            end_step: Stop applying DeepCache at this timestep (0-1000)

        Returns:
            Tuple containing the patched model
        """
        logger = logging.getLogger(__name__)

        # Handle both raw model and tuple input
        if isinstance(model, (tuple, list)):
            model = model[0]

        # Clone the model to avoid modifying the original
        new_model = model.clone()

        # Guard against a zero/negative interval: `current_step % cache_interval`
        # would raise ZeroDivisionError on every call, which the broad except
        # below would swallow — silently disabling DeepCache and spamming the log.
        cache_interval = max(1, cache_interval)

        # State variables for cache management (shared with the wrapper via closure)
        current_t = -1        # last seen timestep value
        current_step = -1     # steps elapsed inside the active caching window
        cached_output = None  # last full forward-pass output, reused on cache-hit steps

        def apply_model_deepcache(model_function, kwargs):
            """Wrapper function that applies DeepCache logic to model forward pass.

            DeepCache works by simply reusing the output from previous steps instead
            of recomputing the full U-Net forward pass. This is much simpler and more
            robust than trying to manually execute partial U-Net blocks.
            """
            nonlocal current_t, current_step, cached_output
            try:
                # Extract inputs from kwargs
                xa = kwargs["input"]
                t = kwargs["timestep"]
                c_dict = kwargs.get("c", {})

                # Get the diffusion model (UNet) for validation
                try:
                    unet = new_model.get_model_object(object_to_patch)
                except Exception:
                    # If we can't get the object, just run normally
                    return model_function(xa, t, **c_dict)

                # Check if this is a UNet-based model (SD1.5, SD2.1, SDXL, etc.);
                # other architectures skip DeepCache entirely.
                if not hasattr(unet, "input_blocks") or not hasattr(unet, "output_blocks"):
                    return model_function(xa, t, **c_dict)

                # Get current timestep value
                current_t_value = t[0].item()

                # Reset step counter if timestep increased (new batch/generation);
                # within one generation t decreases monotonically (999 -> 0).
                if current_t_value > current_t:
                    current_step = -1
                    cached_output = None
                current_t = current_t_value

                # Determine if we should apply caching at this timestep.
                # Note: t goes from 999 -> 0 during generation, hence the 1000-x mapping.
                apply = (1000 - end_step) <= current_t <= (1000 - start_step)
                if apply:
                    current_step += 1
                else:
                    current_step = -1
                    cached_output = None

                # A "cache step" recomputes the full model every `cache_interval` steps;
                # the steps in between reuse the cached output.
                is_cache_step = (current_step % cache_interval == 0) if apply else True

                # If not applying DeepCache or it's a cache update step, run full model
                if not apply or is_cache_step:
                    result = model_function(xa, t, **c_dict)
                    # Store the output for future reuse
                    if apply:
                        cached_output = result.clone() if hasattr(result, "clone") else result
                    return result

                # Cache reuse step - return cached output instead of recomputing
                if cached_output is not None:
                    return cached_output

                # First non-cache step but no cache yet - run normally and cache
                result = model_function(xa, t, **c_dict)
                cached_output = result.clone() if hasattr(result, "clone") else result
                return result
            except Exception as e:
                # Any error: log, reset the cache, and fall back to the normal forward.
                # Lazy %-style args avoid formatting cost unless the record is emitted.
                logger.error("DeepCache wrapper error: %s", e)
                cached_output = None
                return model_function(xa, t, **c_dict)

        # Apply the wrapper
        new_model.set_model_unet_function_wrapper(apply_model_deepcache)
        return (new_model,)