"""Protocol definitions for LightDiffusion-Next.

This module defines the contracts (interfaces) that all components must follow.
Using Protocol allows for structural subtyping without requiring explicit inheritance.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable

import torch

if TYPE_CHECKING:
    from src.Core.Context import Context


# ============================================================================
# MODEL PROTOCOLS
# ============================================================================

@dataclass
class ModelCapabilities:
    """Describes what a model implementation can do."""
    min_resolution: int = 256
    max_resolution: int = 2048
    preferred_resolution: int = 512
    requires_resolution_multiple: int = 64
    supports_hires_fix: bool = True
    supports_img2img: bool = True
    supports_inpainting: bool = False
    supports_controlnet: bool = False
    supports_stable_fast: bool = True
    supports_deepcache: bool = True
    supports_tome: bool = True
    uses_dual_clip: bool = False
    requires_size_conditioning: bool = False
    
    def validate_resolution(self, width: int, height: int) -> tuple[int, int]:
        """Clamp and round resolution to model requirements."""
        width = max(self.min_resolution, min(width, self.max_resolution))
        height = max(self.min_resolution, min(height, self.max_resolution))
        width = (width // self.requires_resolution_multiple) * self.requires_resolution_multiple
        height = (height // self.requires_resolution_multiple) * self.requires_resolution_multiple
        return width, height
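
    # Usage sketch: values are clamped to [min_resolution, max_resolution],
    # then rounded *down* to the nearest requires_resolution_multiple, e.g.:
    #
    #     caps = ModelCapabilities()
    #     caps.validate_resolution(1023, 768)   # -> (960, 768)
    #     caps.validate_resolution(5000, 100)   # -> (2048, 256)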


@runtime_checkable
class ModelProtocol(Protocol):
    """Protocol defining the contract for all model implementations."""
    
    model: Any
    clip: Any
    vae: Any
    model_path: str
    
    @property
    def capabilities(self) -> ModelCapabilities:
        """Return model capabilities."""
        ...
    
    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        ...
    
    def load(self, model_path: str | None = None) -> "ModelProtocol":
        """Load model from disk."""
        ...
    
    def encode_prompt(
        self,
        prompt: str | list[str],
        negative_prompt: str | list[str] = "",
        clip_skip: int = -2,
    ) -> tuple[Any, Any]:
        """Encode prompts to conditioning."""
        ...
    
    def generate(
        self,
        ctx: "Context",
        positive: Any,
        negative: Any,
    ) -> dict:
        """Generate latents."""
        ...
    
    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode latents to images."""
        ...
    
    def apply_lora(
        self,
        lora_name: str,
        strength_model: float = 1.0,
        strength_clip: float = 1.0,
    ) -> "ModelProtocol":
        """Apply LoRA weights."""
        ...
    
    def apply_stable_fast(self, enable_cuda_graph: bool = True) -> "ModelProtocol":
        """Apply StableFast optimization."""
        ...
    
    def apply_deepcache(
        self,
        cache_interval: int = 3,
        cache_depth: int = 2,
        start_step: int = 0,
        end_step: int = 1000,
    ) -> "ModelProtocol":
        """Apply DeepCache optimization."""
        ...
    
    def unload(self) -> None:
        """Release model resources."""
        ...
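
# Because ModelProtocol is @runtime_checkable, conformance is structural:
# isinstance() only verifies that the listed attributes and methods exist
# (it does not check signatures). A sketch, where SD15Model is a
# hypothetical name and not part of this module:
#
#     model = SD15Model().load("checkpoints/v1-5.safetensors")
#     assert isinstance(model, ModelProtocol)
#     positive, negative = model.encode_prompt("a cat", negative_prompt="blurry")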


# ============================================================================
# PROCESSOR PROTOCOLS
# ============================================================================

@runtime_checkable
class ProcessorProtocol(Protocol):
    """Protocol for pipeline processors (plugins).
    
    Processors are stateless components that can optionally modify
    the pipeline context based on feature flags.
    """
    
    @staticmethod
    def is_enabled(ctx: "Context") -> bool:
        """Check if this processor should run for given context."""
        ...
    
    @staticmethod
    def process(
        ctx: "Context",
        model: ModelProtocol,
        **kwargs
    ) -> "Context":
        """Process the context, potentially modifying latents/images.
        
        Args:
            ctx: Pipeline context (may be modified)
            model: Loaded model for any re-sampling needed
            **kwargs: Processor-specific arguments
            
        Returns:
            Modified context
        """
        ...


class BaseProcessor(ABC):
    """Abstract base class for processors providing common functionality."""
    
    @staticmethod
    @abstractmethod
    def is_enabled(ctx: "Context") -> bool:
        """Check if this processor should run."""
        pass
    
    @staticmethod
    @abstractmethod
    def process(ctx: "Context", model: ModelProtocol, **kwargs) -> "Context":
        """Process the context."""
        pass
    
    @classmethod
    def run_if_enabled(cls, ctx: "Context", model: ModelProtocol, **kwargs) -> "Context":
        """Convenience method to conditionally run processor."""
        if cls.is_enabled(ctx):
            return cls.process(ctx, model, **kwargs)
        return ctx
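
# A minimal processor sketch built on BaseProcessor. The ctx attributes
# used here (hires_fix, latent) are illustrative assumptions, not the
# real Context fields:
#
#     class HiresFixProcessor(BaseProcessor):
#         @staticmethod
#         def is_enabled(ctx: "Context") -> bool:
#             return getattr(ctx, "hires_fix", False)
#
#         @staticmethod
#         def process(ctx: "Context", model: ModelProtocol, **kwargs) -> "Context":
#             # upscale ctx.latent, re-sample with model, write back
#             return ctx
#
#     ctx = HiresFixProcessor.run_if_enabled(ctx, model)  # no-op when disabled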


# ============================================================================
# SAMPLER PROTOCOLS
# ============================================================================

@runtime_checkable
class SamplerProtocol(Protocol):
    """Protocol for diffusion samplers."""
    
    def sample(
        self,
        model: Any,
        x: torch.Tensor,
        sigmas: torch.Tensor,
        extra_args: dict | None = None,
        callback: Any = None,
        disable: bool | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the sampling loop.
        
        Args:
            model: The denoising model
            x: Initial noisy latents
            sigmas: Noise schedule
            extra_args: Model-specific arguments
            callback: Progress callback
            disable: Disable progress bar
            **kwargs: Sampler-specific options
            
        Returns:
            Denoised latents
        """
        ...
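
# A minimal reference implementation of SamplerProtocol (sketch only).
class _EulerSamplerSketch:
    """Illustrative Euler sampler, assuming the k-diffusion convention
    that model(x, sigma_batch, **extra_args) returns the denoised
    prediction. Not the production sampler.
    """

    @torch.no_grad()
    def sample(
        self,
        model: Any,
        x: torch.Tensor,
        sigmas: torch.Tensor,
        extra_args: dict | None = None,
        callback: Any = None,
        disable: bool | None = None,  # progress bar flag; unused in this sketch
        **kwargs,
    ) -> torch.Tensor:
        extra_args = extra_args or {}
        for i in range(len(sigmas) - 1):
            sigma = sigmas[i] * x.new_ones(x.shape[0])
            denoised = model(x, sigma, **extra_args)
            d = (x - denoised) / sigmas[i]             # dx/dsigma estimate
            x = x + d * (sigmas[i + 1] - sigmas[i])    # explicit Euler step
            if callback is not None:
                callback({"i": i, "x": x, "denoised": denoised})
        return x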


# ============================================================================
# CFG SCHEDULER PROTOCOLS
# ============================================================================

@runtime_checkable
class CFGSchedulerProtocol(Protocol):
    """Protocol for CFG scheduling strategies."""
    
    def get_cfg(self, step: int, total_steps: int, base_cfg: float) -> float:
        """Get CFG scale for a given step.
        
        Args:
            step: Current step (0-indexed)
            total_steps: Total number of steps
            base_cfg: Base CFG value
            
        Returns:
            CFG scale to use for this step
        """
        ...


class ConstantCFGScheduler:
    """Default CFG scheduler - constant value throughout."""
    
    def get_cfg(self, step: int, total_steps: int, base_cfg: float) -> float:
        return base_cfg


class CFGFreeScheduler:
    """CFG-free sampling - drops to CFG=1 after a percentage of steps."""
    
    def __init__(self, start_percent: float = 70.0):
        self.start_percent = start_percent
    
    def get_cfg(self, step: int, total_steps: int, base_cfg: float) -> float:
        progress = (step / max(1, total_steps - 1)) * 100
        if progress >= self.start_percent:
            return 1.0
        return base_cfg


class LinearDecayCFGScheduler:
    """Linearly decay CFG from base to target."""
    
    def __init__(self, target_cfg: float = 1.0):
        self.target_cfg = target_cfg
    
    def get_cfg(self, step: int, total_steps: int, base_cfg: float) -> float:
        progress = step / max(1, total_steps - 1)
        return base_cfg + (self.target_cfg - base_cfg) * progress
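
# Example: with base_cfg=7.0 and total_steps=5, progress runs 0 -> 1
# across the steps, so:
#
#     LinearDecayCFGScheduler(target_cfg=1.0).get_cfg(2, 5, 7.0)  # -> 4.0
#     CFGFreeScheduler(start_percent=70.0).get_cfg(3, 5, 7.0)     # -> 1.0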


# ============================================================================
# PIPELINE PROTOCOL
# ============================================================================

@runtime_checkable
class PipelineProtocol(Protocol):
    """Protocol for the main generation pipeline."""
    
    def run(self, ctx: "Context") -> "Context":
        """Execute the full generation pipeline.
        
        Args:
            ctx: Configured context with all parameters
            
        Returns:
            Context with generated images in current_image
        """
        ...
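
# A pipeline conforming to this protocol would typically encode prompts,
# sample, then chain processors via run_if_enabled. A sketch; the
# processor list and ctx/self attribute names are assumptions:
#
#     def run(self, ctx: "Context") -> "Context":
#         positive, negative = self.model.encode_prompt(ctx.prompt, ctx.negative_prompt)
#         ctx.latent = self.model.generate(ctx, positive, negative)
#         for processor in self.processors:
#             ctx = processor.run_if_enabled(ctx, self.model)
#         return ctx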


# ============================================================================
# TYPE ALIASES
# ============================================================================

# Conditioning tuple type (used throughout the codebase)
Conditioning = list[tuple[torch.Tensor, dict[str, Any]]]
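# e.g. a single prompt encodes to [(embeddings, {"pooled_output": pooled})];
# the dict key shown is an assumption for illustration.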

# Latent dict type
LatentDict = dict[str, torch.Tensor]
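# e.g. {"samples": latents} with latents shaped [B, C, H // 8, W // 8];
# the "samples" key is an assumption for illustration.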