File size: 5,760 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""StableFast optimization processor for LightDiffusion-Next.

Applies torch.compile and CUDA graph optimizations to models.
"""

import logging
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from src.Core.Context import Context
    from src.Core.AbstractModel import AbstractModel


class StableFastProcessor:
    """StableFast model optimization processor.

    Wraps src/StableFast/ as a standardized processor for model optimization.
    This is typically applied during model loading, not during generation.

    The class is stateless: every entry point is a classmethod.
    """

    @classmethod
    def is_enabled(cls, ctx: "Context") -> bool:
        """Check if StableFast should be applied.

        Args:
            ctx: Pipeline context; ``ctx.generation.stable_fast`` is read.

        Returns:
            True when the flag is present and truthy, False otherwise
            (including when the attribute is missing or set to ``None``).
        """
        # Coerce so the annotated ``bool`` contract holds even if the
        # setting carries None or another non-bool value.
        return bool(getattr(ctx.generation, "stable_fast", False))

    @classmethod
    def is_available(cls) -> bool:
        """Check if StableFast is available in the environment.

        Returns:
            True if ``src.StableFast.StableFast`` can be imported,
            False if the import raises ``ImportError``.
        """
        try:
            from src.StableFast import StableFast  # noqa: F401 - import probe only
        except ImportError:
            return False
        return True

    @classmethod
    def apply(
        cls,
        model: "AbstractModel",
        enable_cuda_graph: bool = True,
    ) -> "AbstractModel":
        """Apply StableFast optimization to a model.

        This is best-effort: any exception raised while importing or running
        StableFast is logged as a warning and the model is returned as-is.

        Args:
            model: Model to optimize
            enable_cuda_graph: Whether to enable CUDA graphs

        Returns:
            Optimized model (same instance, modified in place)
        """
        logger = logging.getLogger(__name__)

        if not model.capabilities.supports_stable_fast:
            logger.info("Model does not support StableFast, skipping")
            return model

        try:
            # Imported lazily so this module loads without StableFast installed.
            from src.StableFast import StableFast

            applier = StableFast.ApplyStableFastUnet()
            result = applier.apply_stable_fast(
                enable_cuda_graph=enable_cuda_graph,
                model=model.model,
            )
            # The first element of the result is used as the patched model.
            model.model = result[0]
        except Exception as e:
            # Deliberate best-effort: never let an optimization failure
            # abort the caller; keep the traceback available at DEBUG level.
            logger.warning("StableFast optimization failed: %s", e)
            logger.debug("StableFast optimization traceback", exc_info=True)
        else:
            logger.info("StableFast optimization applied")

        return model

    @classmethod
    def process(
        cls,
        ctx: "Context",
        model: "AbstractModel",
        enable_cuda_graph: bool = True,
        **kwargs,
    ) -> "Context":
        """Process context, applying StableFast to the model.

        Note: This modifies the model in place.

        Args:
            ctx: Pipeline context
            model: Model to optimize
            enable_cuda_graph: Whether to enable CUDA graphs
            **kwargs: Additional parameters (accepted and ignored here)

        Returns:
            Unchanged context (model is modified in place)
        """
        if cls.is_enabled(ctx):
            cls.apply(model, enable_cuda_graph=enable_cuda_graph)

        return ctx


class DeepCacheProcessor:
    """DeepCache optimization processor.

    Enables feature caching in the U-Net for faster inference.

    The class is stateless: every entry point is a classmethod.
    """

    @classmethod
    def is_enabled(cls, ctx: "Context") -> bool:
        """Check if DeepCache should be applied.

        Args:
            ctx: Pipeline context; ``ctx.sampling.deepcache_enabled`` is read.

        Returns:
            True when the flag is present and truthy, False otherwise
            (including when the attribute is missing or set to ``None``).
        """
        # Coerce so the annotated ``bool`` contract holds even if the
        # setting carries None or another non-bool value.
        return bool(getattr(ctx.sampling, "deepcache_enabled", False))

    @classmethod
    def apply(
        cls,
        model: "AbstractModel",
        cache_interval: int = 3,
        cache_depth: int = 2,
        start_step: int = 0,
        end_step: int = 1000,
    ) -> "AbstractModel":
        """Apply DeepCache optimization to a model.

        This is best-effort: any exception raised while importing or running
        DeepCache is logged as a warning and the model is returned as-is.

        Args:
            model: Model to optimize
            cache_interval: Steps between cache updates
            cache_depth: U-Net depth for caching
            start_step: Start applying at this step
            end_step: Stop applying at this step

        Returns:
            Optimized model (same instance; ``model.model`` is replaced when
            patching succeeds)
        """
        logger = logging.getLogger(__name__)

        if not model.capabilities.supports_deepcache:
            logger.info("Model does not support DeepCache, skipping")
            return model

        try:
            # Imported lazily so this module loads without WaveSpeed installed.
            from src.WaveSpeed import deepcache_nodes

            deepcache = deepcache_nodes.ApplyDeepCacheOnModel()
            # NOTE(review): the inner model is passed wrapped in a 1-tuple;
            # presumably that is what ``patch`` expects - confirm against
            # deepcache_nodes.
            result = deepcache.patch(
                model=(model.model,),
                object_to_patch="diffusion_model",
                cache_interval=cache_interval,
                cache_depth=cache_depth,
                start_step=start_step,
                end_step=end_step,
            )

            if isinstance(result, tuple) and len(result) > 0:
                model.model = result[0]
                logger.info(
                    "DeepCache applied (interval=%s, depth=%s)",
                    cache_interval,
                    cache_depth,
                )
            else:
                # Do not report success when nothing was actually patched.
                logger.warning(
                    "DeepCache patch returned an unexpected result (%s); "
                    "model left unchanged",
                    type(result).__name__,
                )
        except Exception as e:
            # Deliberate best-effort: never let an optimization failure
            # abort the caller; keep the traceback available at DEBUG level.
            logger.warning("DeepCache optimization failed: %s", e)
            logger.debug("DeepCache optimization traceback", exc_info=True)

        return model

    @classmethod
    def process(
        cls,
        ctx: "Context",
        model: "AbstractModel",
        **kwargs,
    ) -> "Context":
        """Process context, applying DeepCache to the model.

        Cache settings are read from ``ctx.sampling`` (``deepcache_interval``,
        ``deepcache_depth``, ``deepcache_start_step``, ``deepcache_end_step``),
        falling back to the same defaults as :meth:`apply`.

        Args:
            ctx: Pipeline context with deepcache settings
            model: Model to optimize
            **kwargs: Additional parameters (accepted and ignored here)

        Returns:
            Unchanged context (model is modified in place)
        """
        if not cls.is_enabled(ctx):
            return ctx

        sampling = ctx.sampling
        cls.apply(
            model,
            cache_interval=getattr(sampling, "deepcache_interval", 3),
            cache_depth=getattr(sampling, "deepcache_depth", 2),
            start_step=getattr(sampling, "deepcache_start_step", 0),
            end_step=getattr(sampling, "deepcache_end_step", 1000),
        )

        return ctx