| """Protocol definitions for LightDiffusion-Next. | |
| This module defines the contracts (interfaces) that all components must follow. | |
| Using Protocol allows for structural subtyping without requiring explicit inheritance. | |
| """ | |
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass | |
| from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable | |
| import torch | |
| if TYPE_CHECKING: | |
| from src.Core.Context import Context | |
| # ============================================================================ | |
| # MODEL PROTOCOLS | |
| # ============================================================================ | |
@dataclass
class ModelCapabilities:
    """Describes what a model implementation can do."""

    min_resolution: int = 256
    max_resolution: int = 2048
    preferred_resolution: int = 512
    requires_resolution_multiple: int = 64
    supports_hires_fix: bool = True
    supports_img2img: bool = True
    supports_inpainting: bool = False
    supports_controlnet: bool = False
    supports_stable_fast: bool = True
    supports_deepcache: bool = True
    supports_tome: bool = True
    uses_dual_clip: bool = False
    requires_size_conditioning: bool = False

    def validate_resolution(self, width: int, height: int) -> tuple[int, int]:
        """Clamp and round resolution to model requirements."""
        width = max(self.min_resolution, min(width, self.max_resolution))
        height = max(self.min_resolution, min(height, self.max_resolution))
        width = (width // self.requires_resolution_multiple) * self.requires_resolution_multiple
        height = (height // self.requires_resolution_multiple) * self.requires_resolution_multiple
        return width, height
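
# Example (doctest-style sketch) of the clamp-then-round behavior: 1500x720
# stays within [256, 2048] and rounds down to multiples of 64.
# >>> ModelCapabilities().validate_resolution(1500, 720)
# (1472, 704)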


@runtime_checkable
class ModelProtocol(Protocol):
    """Protocol defining the contract for all model implementations."""

    model: Any
    clip: Any
    vae: Any
    model_path: str

    def capabilities(self) -> ModelCapabilities:
        """Return model capabilities."""
        ...

    def is_loaded(self) -> bool:
        """Check if the model is loaded."""
        ...

    def load(self, model_path: str | None = None) -> "ModelProtocol":
        """Load the model from disk."""
        ...

    def encode_prompt(
        self,
        prompt: str | list[str],
        negative_prompt: str | list[str] = "",
        clip_skip: int = -2,
    ) -> tuple[Any, Any]:
        """Encode prompts to conditioning."""
        ...

    def generate(
        self,
        ctx: "Context",
        positive: Any,
        negative: Any,
    ) -> dict:
        """Generate latents."""
        ...

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode latents to images."""
        ...

    def apply_lora(
        self,
        lora_name: str,
        strength_model: float = 1.0,
        strength_clip: float = 1.0,
    ) -> "ModelProtocol":
        """Apply LoRA weights."""
        ...

    def apply_stable_fast(self, enable_cuda_graph: bool = True) -> "ModelProtocol":
        """Apply StableFast optimization."""
        ...

    def apply_deepcache(
        self,
        cache_interval: int = 3,
        cache_depth: int = 2,
        start_step: int = 0,
        end_step: int = 1000,
    ) -> "ModelProtocol":
        """Apply DeepCache optimization."""
        ...

    def unload(self) -> None:
        """Release model resources."""
        ...
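
# Usage sketch (hypothetical `SD15Model` implementation and `ctx`; the
# "samples" key follows the LatentDict convention defined at the bottom
# of this module):
#   model = SD15Model().load("checkpoints/model.safetensors")
#   positive, negative = model.encode_prompt("a photo of a cat", "blurry, low quality")
#   latent = model.generate(ctx, positive, negative)
#   images = model.decode(latent["samples"])
#   model.unload()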


# ============================================================================
# PROCESSOR PROTOCOLS
# ============================================================================
@runtime_checkable
class ProcessorProtocol(Protocol):
    """Protocol for pipeline processors (plugins).

    Processors are stateless components that can optionally modify
    the pipeline context based on feature flags.
    """

    @staticmethod
    def is_enabled(ctx: "Context") -> bool:
        """Check if this processor should run for the given context."""
        ...

    @staticmethod
    def process(
        ctx: "Context",
        model: ModelProtocol,
        **kwargs,
    ) -> "Context":
        """Process the context, potentially modifying latents/images.

        Args:
            ctx: Pipeline context (may be modified)
            model: Loaded model for any re-sampling needed
            **kwargs: Processor-specific arguments

        Returns:
            Modified context
        """
        ...


class BaseProcessor(ABC):
    """Abstract base class for processors providing common functionality."""

    @staticmethod
    @abstractmethod
    def is_enabled(ctx: "Context") -> bool:
        """Check if this processor should run."""
        ...

    @staticmethod
    @abstractmethod
    def process(ctx: "Context", model: ModelProtocol, **kwargs) -> "Context":
        """Process the context."""
        ...

    @classmethod
    def run_if_enabled(cls, ctx: "Context", model: ModelProtocol, **kwargs) -> "Context":
        """Convenience method to conditionally run the processor."""
        if cls.is_enabled(ctx):
            return cls.process(ctx, model, **kwargs)
        return ctx
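
# Minimal concrete processor sketch (hypothetical; assumes the Context carries
# a `hires_fix` flag -- the real flag name may differ):
#
# class HiresFixProcessor(BaseProcessor):
#     @staticmethod
#     def is_enabled(ctx: "Context") -> bool:
#         return getattr(ctx, "hires_fix", False)
#
#     @staticmethod
#     def process(ctx: "Context", model: ModelProtocol, **kwargs) -> "Context":
#         # Upscale the context's latents and re-sample with `model`, then return ctx.
#         return ctx
#
# `HiresFixProcessor.run_if_enabled(ctx, model)` is then a no-op unless enabled.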


# ============================================================================
# SAMPLER PROTOCOLS
# ============================================================================
@runtime_checkable
class SamplerProtocol(Protocol):
    """Protocol for diffusion samplers."""

    def sample(
        self,
        model: Any,
        x: torch.Tensor,
        sigmas: torch.Tensor,
        extra_args: dict | None = None,
        callback: Any = None,
        disable: bool | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the sampling loop.

        Args:
            model: The denoising model
            x: Initial noisy latents
            sigmas: Noise schedule
            extra_args: Model-specific arguments
            callback: Progress callback
            disable: Disable the progress bar
            **kwargs: Sampler-specific options

        Returns:
            Denoised latents
        """
        ...
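

# A minimal conforming sampler, included as an illustrative sketch only. It
# assumes the k-diffusion convention where `model(x, sigma, **extra_args)`
# returns the denoised prediction; the project's real samplers may differ.
class EulerSampler:
    """Plain Euler integration over the sigma schedule (illustrative sketch)."""

    def sample(
        self,
        model: Any,
        x: torch.Tensor,
        sigmas: torch.Tensor,
        extra_args: dict | None = None,
        callback: Any = None,
        disable: bool | None = None,
        **kwargs,
    ) -> torch.Tensor:
        extra_args = extra_args or {}
        for i in range(len(sigmas) - 1):
            denoised = model(x, sigmas[i], **extra_args)
            d = (x - denoised) / sigmas[i]  # derivative estimate at sigma_i
            x = x + d * (sigmas[i + 1] - sigmas[i])  # Euler step to sigma_{i+1}
            if callback is not None:
                callback({"i": i, "x": x, "sigma": sigmas[i], "denoised": denoised})
        return x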


# ============================================================================
# CFG SCHEDULER PROTOCOLS
# ============================================================================
@runtime_checkable
class CFGSchedulerProtocol(Protocol):
    """Protocol for CFG scheduling strategies."""

    def get_cfg(self, step: int, total_steps: int, base_cfg: float) -> float:
        """Get the CFG scale for a given step.

        Args:
            step: Current step (0-indexed)
            total_steps: Total number of steps
            base_cfg: Base CFG value

        Returns:
            CFG scale to use for this step
        """
        ...


class ConstantCFGScheduler:
    """Default CFG scheduler - constant value throughout."""

    def get_cfg(self, step: int, total_steps: int, base_cfg: float) -> float:
        return base_cfg


class CFGFreeScheduler:
    """CFG-free sampling - drops to CFG=1 after a percentage of steps."""

    def __init__(self, start_percent: float = 70.0):
        self.start_percent = start_percent

    def get_cfg(self, step: int, total_steps: int, base_cfg: float) -> float:
        progress = (step / max(1, total_steps - 1)) * 100
        if progress >= self.start_percent:
            return 1.0
        return base_cfg
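
# Example: with total_steps=30 and start_percent=70.0, progress first reaches
# 70% at step 21 (21 / 29 * 100 ~= 72.4), so CFG drops to 1.0 from there on.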


class LinearDecayCFGScheduler:
    """Linearly decay CFG from base to target."""

    def __init__(self, target_cfg: float = 1.0):
        self.target_cfg = target_cfg

    def get_cfg(self, step: int, total_steps: int, base_cfg: float) -> float:
        progress = step / max(1, total_steps - 1)
        return base_cfg + (self.target_cfg - base_cfg) * progress
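
# Example: base_cfg=7.5 and target_cfg=1.0 over 30 steps yields 7.5 at step 0,
# 4.25 at the halfway point, and exactly 1.0 at the final step.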


# ============================================================================
# PIPELINE PROTOCOL
# ============================================================================
@runtime_checkable
class PipelineProtocol(Protocol):
    """Protocol for the main generation pipeline."""

    def run(self, ctx: "Context") -> "Context":
        """Execute the full generation pipeline.

        Args:
            ctx: Configured context with all parameters

        Returns:
            Context with generated images in current_image
        """
        ...
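
# Usage sketch (hypothetical `Pipeline` implementation; the Context field
# names here are assumptions, except `current_image` from the docstring):
#   ctx = Context(prompt="a castle at dusk", width=832, height=1216, steps=30)
#   ctx = pipeline.run(ctx)
#   images = ctx.current_image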


# ============================================================================
# TYPE ALIASES
# ============================================================================
# Conditioning tuple type (used throughout the codebase)
Conditioning = list[tuple[torch.Tensor, dict[str, Any]]]

# Latent dict type
LatentDict = dict[str, torch.Tensor]
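
# Example shapes (SD1.5-style, an assumption): a Conditioning value looks like
# [(cond, {"pooled_output": pooled})] with cond of shape [1, 77, 768], and a
# LatentDict like {"samples": latents} with latents of shape [1, 4, H/8, W/8].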