"""Protocol definitions for LightDiffusion-Next.
This module defines the contracts (interfaces) that all components must follow.
Using Protocol allows for structural subtyping without requiring explicit inheritance.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
import torch
if TYPE_CHECKING:
from src.Core.Context import Context
# ============================================================================
# MODEL PROTOCOLS
# ============================================================================
@dataclass
class ModelCapabilities:
    """Describes what a model implementation can do.

    Combines resolution constraints with feature flags so the pipeline can
    query a model's abilities before dispatching work to it.
    """
    # Resolution constraints (pixels).
    min_resolution: int = 256
    max_resolution: int = 2048
    preferred_resolution: int = 512
    requires_resolution_multiple: int = 64
    # Feature flags.
    supports_hires_fix: bool = True
    supports_img2img: bool = True
    supports_inpainting: bool = False
    supports_controlnet: bool = False
    supports_stable_fast: bool = True
    supports_deepcache: bool = True
    supports_tome: bool = True
    uses_dual_clip: bool = False
    requires_size_conditioning: bool = False
    def validate_resolution(self, width: int, height: int) -> tuple[int, int]:
        """Clamp and round resolution to model requirements.

        Each dimension is clamped to [min_resolution, max_resolution] and
        then floored to a multiple of ``requires_resolution_multiple``.
        The result is guaranteed to be at least one multiple — never 0 —
        even if ``min_resolution`` is configured smaller than the multiple
        (plain flooring would otherwise yield 0 for small inputs).

        Args:
            width: Requested width in pixels.
            height: Requested height in pixels.

        Returns:
            A ``(width, height)`` tuple satisfying the model's constraints.
        """
        multiple = self.requires_resolution_multiple
        width = max(self.min_resolution, min(width, self.max_resolution))
        height = max(self.min_resolution, min(height, self.max_resolution))
        # Floor to the required multiple, but never below one multiple:
        # a misconfigured min_resolution < multiple would otherwise let
        # flooring produce a 0-sized dimension.
        width = max(multiple, (width // multiple) * multiple)
        height = max(multiple, (height // multiple) * multiple)
        return width, height
@runtime_checkable
class ModelProtocol(Protocol):
    """Protocol defining the contract for all model implementations.

    Structural interface (PEP 544): any object providing these attributes
    and methods satisfies the protocol without explicit inheritance.
    ``runtime_checkable`` enables ``isinstance`` checks (method presence
    only — attribute types are not verified at runtime).
    """
    model: Any  # underlying diffusion model object
    clip: Any  # text encoder used by encode_prompt
    vae: Any  # VAE used by decode
    model_path: str  # filesystem path associated with the model
    @property
    def capabilities(self) -> ModelCapabilities:
        """Return model capabilities (resolution limits and feature flags)."""
        ...
    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        ...
    def load(self, model_path: str | None = None) -> "ModelProtocol":
        """Load model from disk.

        Args:
            model_path: Optional path override; presumably falls back to
                ``self.model_path`` when None — TODO confirm against
                implementations.

        Returns:
            The model itself, allowing fluent chaining.
        """
        ...
    def encode_prompt(
        self,
        prompt: str | list[str],
        negative_prompt: str | list[str] = "",
        clip_skip: int = -2,
    ) -> tuple[Any, Any]:
        """Encode prompts to conditioning.

        Args:
            prompt: Positive prompt text (or batch of prompts).
            negative_prompt: Negative prompt text (or batch).
            clip_skip: CLIP layer selection; negative values presumably
                index from the final layer — TODO confirm.

        Returns:
            ``(positive, negative)`` conditioning pair.
        """
        ...
    def generate(
        self,
        ctx: "Context",
        positive: Any,
        negative: Any,
    ) -> dict:
        """Generate latents.

        Args:
            ctx: Pipeline context carrying generation parameters.
            positive: Positive conditioning from :meth:`encode_prompt`.
            negative: Negative conditioning from :meth:`encode_prompt`.

        Returns:
            Latent dict (see ``LatentDict`` alias at module bottom).
        """
        ...
    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode latents to images."""
        ...
    def apply_lora(
        self,
        lora_name: str,
        strength_model: float = 1.0,
        strength_clip: float = 1.0,
    ) -> "ModelProtocol":
        """Apply LoRA weights.

        Args:
            lora_name: Identifier/filename of the LoRA to apply.
            strength_model: Weight applied to the diffusion model patches.
            strength_clip: Weight applied to the text-encoder patches.

        Returns:
            The model itself, allowing fluent chaining.
        """
        ...
    def apply_stable_fast(self, enable_cuda_graph: bool = True) -> "ModelProtocol":
        """Apply StableFast optimization."""
        ...
    def apply_deepcache(
        self,
        cache_interval: int = 3,
        cache_depth: int = 2,
        start_step: int = 0,
        end_step: int = 1000,
    ) -> "ModelProtocol":
        """Apply DeepCache optimization.

        Args:
            cache_interval: Steps between cache refreshes.
            cache_depth: UNet depth at which features are cached —
                TODO confirm semantics against the implementation.
            start_step: First step where caching is active.
            end_step: Last step where caching is active.

        Returns:
            The model itself, allowing fluent chaining.
        """
        ...
    def unload(self) -> None:
        """Release model resources."""
        ...
# ============================================================================
# PROCESSOR PROTOCOLS
# ============================================================================
@runtime_checkable
class ProcessorProtocol(Protocol):
    """Protocol for pipeline processors (plugins).

    Processors are stateless components that can optionally modify
    the pipeline context based on feature flags. Both members are
    static so a processor class never needs instantiation.
    """
    @staticmethod
    def is_enabled(ctx: "Context") -> bool:
        """Check if this processor should run for given context."""
        ...
    @staticmethod
    def process(
        ctx: "Context",
        model: ModelProtocol,
        **kwargs
    ) -> "Context":
        """Process the context, potentially modifying latents/images.

        Args:
            ctx: Pipeline context (may be modified)
            model: Loaded model for any re-sampling needed
            **kwargs: Processor-specific arguments

        Returns:
            Modified context
        """
        ...
class BaseProcessor(ABC):
    """Abstract base class for processors providing common functionality.

    Concrete processors implement the two static hooks; ``run_if_enabled``
    provides the shared gate-then-run convenience.
    """
    @staticmethod
    @abstractmethod
    def is_enabled(ctx: "Context") -> bool:
        """Report whether this processor should run for the given context."""
        ...
    @staticmethod
    @abstractmethod
    def process(ctx: "Context", model: "ModelProtocol", **kwargs) -> "Context":
        """Transform the context and return it."""
        ...
    @classmethod
    def run_if_enabled(cls, ctx: "Context", model: "ModelProtocol", **kwargs) -> "Context":
        """Run ``process`` only when ``is_enabled`` allows it; otherwise pass ctx through unchanged."""
        return cls.process(ctx, model, **kwargs) if cls.is_enabled(ctx) else ctx
# ============================================================================
# SAMPLER PROTOCOLS
# ============================================================================
@runtime_checkable
class SamplerProtocol(Protocol):
    """Protocol for diffusion samplers."""
    def sample(
        self,
        model: Any,
        x: torch.Tensor,
        sigmas: torch.Tensor,
        extra_args: dict | None = None,
        callback: Any = None,
        disable: bool | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the sampling loop.

        Args:
            model: The denoising model
            x: Initial noisy latents
            sigmas: Noise schedule
            extra_args: Model-specific arguments
            callback: Progress callback
            disable: Disable progress bar
            **kwargs: Sampler-specific options

        Returns:
            Denoised latents
        """
        ...
# ============================================================================
# CFG SCHEDULER PROTOCOLS
# ============================================================================
@runtime_checkable
class CFGSchedulerProtocol(Protocol):
    """Protocol for CFG scheduling strategies.

    Implementations map the sampling progress to a classifier-free
    guidance scale; see the concrete schedulers below.
    """
    def get_cfg(self, step: int, total_steps: int, base_cfg: float) -> float:
        """Get CFG scale for a given step.

        Args:
            step: Current step (0-indexed)
            total_steps: Total number of steps
            base_cfg: Base CFG value

        Returns:
            CFG scale to use for this step
        """
        ...
class ConstantCFGScheduler:
    """Default CFG scheduler: the base CFG value is used at every step."""
    def get_cfg(self, step: int, total_steps: int, base_cfg: float) -> float:
        """Return ``base_cfg`` unchanged regardless of sampling progress."""
        return base_cfg
class CFGFreeScheduler:
    """CFG-free sampling - drops to CFG=1 after a percentage of steps."""
    def __init__(self, start_percent: float = 70.0):
        # Percentage of the schedule after which guidance is switched off.
        self.start_percent = start_percent
    def get_cfg(self, step: int, total_steps: int, base_cfg: float) -> float:
        """Return 1.0 once progress reaches ``start_percent``, else ``base_cfg``."""
        # max(1, ...) guards the single-step schedule against division by zero.
        last_index = max(1, total_steps - 1)
        percent_done = (step / last_index) * 100
        if percent_done < self.start_percent:
            return base_cfg
        return 1.0
class LinearDecayCFGScheduler:
    """Linearly decay CFG from base to target."""
    def __init__(self, target_cfg: float = 1.0):
        # CFG value reached on the final step of the schedule.
        self.target_cfg = target_cfg
    def get_cfg(self, step: int, total_steps: int, base_cfg: float) -> float:
        """Interpolate between ``base_cfg`` (step 0) and ``target_cfg`` (last step)."""
        # max(1, ...) guards the single-step schedule against division by zero.
        fraction = step / max(1, total_steps - 1)
        delta = self.target_cfg - base_cfg
        return base_cfg + delta * fraction
# ============================================================================
# PIPELINE PROTOCOL
# ============================================================================
@runtime_checkable
class PipelineProtocol(Protocol):
    """Protocol for the main generation pipeline."""
    def run(self, ctx: "Context") -> "Context":
        """Execute the full generation pipeline.

        Args:
            ctx: Configured context with all parameters

        Returns:
            Context with generated images in current_image
        """
        ...
# ============================================================================
# TYPE ALIASES
# ============================================================================
# Conditioning: list of (embedding tensor, metadata dict) pairs — the shape
# produced by prompt encoding and consumed throughout the codebase.
Conditioning = list[tuple[torch.Tensor, dict[str, Any]]]
# LatentDict: named latent tensors passed between pipeline stages.
# NOTE(review): key names are not visible in this module — verify against
# ModelProtocol.generate implementations.
LatentDict = dict[str, torch.Tensor]