File size: 5,554 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""DeepCache implementation for LightDiffusion-Next.

Based on:
- https://github.com/horseee/DeepCache
- https://gist.github.com/laksjdjf/435c512bc19636e9c9af4ee7bea9eb86

DeepCache accelerates diffusion models by reusing high-level features
while updating low-level features in a cheap way.
"""

import torch
import logging


class ApplyDeepCacheOnModel:
    """Apply DeepCache optimization to a model.

    DeepCache accelerates diffusion sampling by running the full U-Net only
    on periodic "cache" steps and reusing the cached output on the steps in
    between, skipping the expensive forward pass entirely.
    """

    def patch(
        self,
        model,
        object_to_patch="diffusion_model",
        cache_interval=3,
        cache_depth=2,
        start_step=0,
        end_step=1000,
    ):
        """Patch the model with DeepCache optimization.

        Args:
            model: The model to patch (a ModelPatcher, or a tuple/list whose
                first element is one).
            object_to_patch: Name of the model object to patch
                (default: "diffusion_model").
            cache_interval: Run the full model every ``cache_interval`` steps
                and reuse the cached output in between (higher = more speedup,
                lower quality).
            cache_depth: Accepted for API compatibility; this whole-output
                caching implementation does not use it.
            start_step: Start applying DeepCache at this timestep (0-1000).
            end_step: Stop applying DeepCache at this timestep (0-1000).

        Returns:
            One-element tuple containing the patched (cloned) model.
        """
        logger = logging.getLogger(__name__)

        # Accept both a raw model patcher and a tuple/list wrapping one.
        if isinstance(model, (tuple, list)):
            model = model[0]

        # Clone so the caller's original model is left untouched.
        new_model = model.clone()

        # Closure state shared across sampler steps.
        current_t = -1        # last seen timestep value
        current_step = -1     # step counter within the active caching window
        cached_output = None  # last full U-Net output, reused on skip steps

        def apply_model_deepcache(model_function, kwargs):
            """Wrapper applying DeepCache logic around the model forward pass.

            Instead of partially executing U-Net blocks, this simply reuses
            the previous full output on "skip" steps, which is simpler and
            more robust than manual block surgery.
            """
            nonlocal current_t, current_step, cached_output

            # Extract inputs up front with .get so the fallback in the
            # exception handler below can never hit an unbound local.
            # (Previously kwargs["input"] could raise inside the try and the
            # handler would then NameError on xa/t/c_dict.)
            xa = kwargs.get("input")
            t = kwargs.get("timestep")
            c_dict = kwargs.get("c", {})

            try:
                # Get the diffusion model (UNet) for validation.
                try:
                    unet = new_model.get_model_object(object_to_patch)
                except Exception:
                    # Can't resolve the object — run the model unmodified.
                    return model_function(xa, t, **c_dict)

                # Only U-Net architectures (SD1.5, SD2.1, SDXL, etc.) are
                # supported; anything else runs unmodified.
                if not hasattr(unet, "input_blocks") or not hasattr(unet, "output_blocks"):
                    return model_function(xa, t, **c_dict)

                # Current timestep value.
                current_t_value = t[0].item()

                # Timesteps decrease during sampling; an increase means a new
                # batch/generation started, so reset the cache state.
                if current_t_value > current_t:
                    current_step = -1
                    cached_output = None

                current_t = current_t_value

                # t runs 999 -> 0 during generation, so the user-facing
                # start/end steps are mirrored into timestep space.
                apply = (1000 - end_step) <= current_t <= (1000 - start_step)

                if apply:
                    current_step += 1
                else:
                    current_step = -1
                    cached_output = None

                # Every cache_interval-th step refreshes the cache; the
                # others reuse it.
                is_cache_step = (current_step % cache_interval == 0) if apply else True

                # Not applying DeepCache, or a cache refresh step: full pass.
                if not apply or is_cache_step:
                    result = model_function(xa, t, **c_dict)
                    if apply:
                        # Clone tensors so later in-place edits can't
                        # corrupt the cache; non-tensors are stored as-is.
                        cached_output = result.clone() if hasattr(result, "clone") else result
                    return result

                # Skip step: reuse the cached output instead of recomputing.
                if cached_output is not None:
                    return cached_output

                # Skip step but no cache yet — compute and remember it.
                result = model_function(xa, t, **c_dict)
                cached_output = result.clone() if hasattr(result, "clone") else result
                return result

            except Exception as e:
                # On any failure, reset the cache and fall back to the normal
                # forward pass (inputs were extracted pre-try, so this path
                # is always well-defined).
                logger.error(f"DeepCache wrapper error: {e}")
                cached_output = None
                return model_function(xa, t, **c_dict)

        # Install the wrapper on the cloned model.
        new_model.set_model_unet_function_wrapper(apply_model_deepcache)

        return (new_model,)