"""Protocol definitions for LightDiffusion-Next.
This module defines the contracts (interfaces) that all components must follow.
Using Protocol allows for structural subtyping without requiring explicit inheritance.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
import torch
if TYPE_CHECKING:
from src.Core.Context import Context
# ============================================================================
# MODEL PROTOCOLS
# ============================================================================
@dataclass
class ModelCapabilities:
"""Describes what a model implementation can do."""
min_resolution: int = 256
max_resolution: int = 2048
preferred_resolution: int = 512
requires_resolution_multiple: int = 64
supports_hires_fix: bool = True
supports_img2img: bool = True
supports_inpainting: bool = False
supports_controlnet: bool = False
supports_stable_fast: bool = True
supports_deepcache: bool = True
supports_tome: bool = True
uses_dual_clip: bool = False
requires_size_conditioning: bool = False
def validate_resolution(self, width: int, height: int) -> tuple[int, int]:
"""Clamp and round resolution to model requirements."""
        # Clamp into the supported range, then floor to the required multiple
        # (64 by default), so dimensions stay valid for the model.
        width = max(self.min_resolution, min(width, self.max_resolution))
        height = max(self.min_resolution, min(height, self.max_resolution))
        width = (width // self.requires_resolution_multiple) * self.requires_resolution_multiple
        height = (height // self.requires_resolution_multiple) * self.requires_resolution_multiple
return width, height
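
# Illustrative use of validate_resolution (values below assume the default
# capabilities): a 1000x3000 request is clamped to max_resolution and then
# floored to the nearest multiple of 64.
#
#     caps = ModelCapabilities()
#     caps.validate_resolution(1000, 3000)  # -> (960, 2048)
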
@runtime_checkable
class ModelProtocol(Protocol):
"""Protocol defining the contract for all model implementations."""
model: Any
clip: Any
vae: Any
model_path: str
@property
def capabilities(self) -> ModelCapabilities:
"""Return model capabilities."""
...
@property
def is_loaded(self) -> bool:
"""Check if model is loaded."""
...
    def load(self, model_path: str | None = None) -> "ModelProtocol":
"""Load model from disk."""
...
def encode_prompt(
self,
prompt: str | list[str],
negative_prompt: str | list[str] = "",
clip_skip: int = -2,
) -> tuple[Any, Any]:
"""Encode prompts to conditioning."""
...
def generate(
self,
ctx: "Context",
positive: Any,
negative: Any,
) -> dict:
"""Generate latents."""
...
def decode(self, latents: torch.Tensor) -> torch.Tensor:
"""Decode latents to images."""
...
def apply_lora(
self,
lora_name: str,
strength_model: float = 1.0,
strength_clip: float = 1.0,
) -> "ModelProtocol":
"""Apply LoRA weights."""
...
def apply_stable_fast(self, enable_cuda_graph: bool = True) -> "ModelProtocol":
"""Apply StableFast optimization."""
...
def apply_deepcache(
self,
cache_interval: int = 3,
cache_depth: int = 2,
start_step: int = 0,
end_step: int = 1000,
) -> "ModelProtocol":
"""Apply DeepCache optimization."""
...
def unload(self) -> None:
"""Release model resources."""
...
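
# Because ModelProtocol is @runtime_checkable, structural isinstance() checks
# work against any object exposing the required attributes and methods. Note
# that runtime checks only verify that members exist; signatures are not
# validated. Sketch (`candidate` is a hypothetical object, not part of this
# module):
#
#     if isinstance(candidate, ModelProtocol) and not candidate.is_loaded:
#         candidate.load()
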
# ============================================================================
# PROCESSOR PROTOCOLS
# ============================================================================
@runtime_checkable
class ProcessorProtocol(Protocol):
"""Protocol for pipeline processors (plugins).
Processors are stateless components that can optionally modify
the pipeline context based on feature flags.
"""
@staticmethod
def is_enabled(ctx: "Context") -> bool:
"""Check if this processor should run for given context."""
...
@staticmethod
def process(
ctx: "Context",
model: ModelProtocol,
**kwargs
) -> "Context":
"""Process the context, potentially modifying latents/images.
Args:
ctx: Pipeline context (may be modified)
model: Loaded model for any re-sampling needed
**kwargs: Processor-specific arguments
Returns:
Modified context
"""
...
class BaseProcessor(ABC):
"""Abstract base class for processors providing common functionality."""
@staticmethod
@abstractmethod
def is_enabled(ctx: "Context") -> bool:
"""Check if this processor should run."""
pass
@staticmethod
@abstractmethod
def process(ctx: "Context", model: ModelProtocol, **kwargs) -> "Context":
"""Process the context."""
pass
@classmethod
def run_if_enabled(cls, ctx: "Context", model: ModelProtocol, **kwargs) -> "Context":
"""Convenience method to conditionally run processor."""
if cls.is_enabled(ctx):
return cls.process(ctx, model, **kwargs)
return ctx
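
# Illustrative sketch of a concrete processor built on BaseProcessor. The
# `ctx.hires_fix` flag is a hypothetical Context attribute used only to show
# the is_enabled/process contract; real processors key off their own flags.


class _ExampleNoOpProcessor(BaseProcessor):
    """Does nothing; demonstrates the processor contract."""

    @staticmethod
    def is_enabled(ctx: "Context") -> bool:
        # Hypothetical feature flag; real processors check their own.
        return bool(getattr(ctx, "hires_fix", False))

    @staticmethod
    def process(ctx: "Context", model: ModelProtocol, **kwargs) -> "Context":
        # A real processor would modify ctx (latents/images) here.
        return ctx


# Usage: ctx = _ExampleNoOpProcessor.run_if_enabled(ctx, model)
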
# ============================================================================
# SAMPLER PROTOCOLS
# ============================================================================
@runtime_checkable
class SamplerProtocol(Protocol):
"""Protocol for diffusion samplers."""
def sample(
self,
model: Any,
x: torch.Tensor,
sigmas: torch.Tensor,
        extra_args: dict | None = None,
        callback: Any = None,
        disable: bool | None = None,
**kwargs,
) -> torch.Tensor:
"""Run the sampling loop.
Args:
model: The denoising model
x: Initial noisy latents
sigmas: Noise schedule
extra_args: Model-specific arguments
callback: Progress callback
disable: Disable progress bar
**kwargs: Sampler-specific options
Returns:
Denoised latents
"""
...
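
# Minimal Euler sampler sketch that satisfies SamplerProtocol structurally.
# It assumes the k-diffusion denoiser convention `model(x, sigma, **extra_args)`
# returning the fully denoised sample; the project's real samplers may differ.


class _ExampleEulerSampler:
    """Plain Euler integration over a decreasing sigma schedule."""

    def sample(
        self,
        model: Any,
        x: torch.Tensor,
        sigmas: torch.Tensor,
        extra_args: dict | None = None,
        callback: Any = None,
        disable: bool | None = None,  # accepted for protocol compatibility; unused here
        **kwargs,
    ) -> torch.Tensor:
        extra_args = {} if extra_args is None else extra_args
        for i in range(len(sigmas) - 1):
            sigma = sigmas[i] * x.new_ones((x.shape[0],))
            denoised = model(x, sigma, **extra_args)
            d = (x - denoised) / sigmas[i]  # derivative dx/dsigma
            x = x + d * (sigmas[i + 1] - sigmas[i])  # Euler step toward next sigma
            if callback is not None:
                callback({"i": i, "x": x, "denoised": denoised})
        return x
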
# ============================================================================
# CFG SCHEDULER PROTOCOLS
# ============================================================================
@runtime_checkable
class CFGSchedulerProtocol(Protocol):
"""Protocol for CFG scheduling strategies."""
def get_cfg(self, step: int, total_steps: int, base_cfg: float) -> float:
"""Get CFG scale for a given step.
Args:
step: Current step (0-indexed)
total_steps: Total number of steps
base_cfg: Base CFG value
Returns:
CFG scale to use for this step
"""
...
class ConstantCFGScheduler:
"""Default CFG scheduler - constant value throughout."""
def get_cfg(self, step: int, total_steps: int, base_cfg: float) -> float:
return base_cfg
class CFGFreeScheduler:
"""CFG-free sampling - drops to CFG=1 after a percentage of steps."""
def __init__(self, start_percent: float = 70.0):
self.start_percent = start_percent
def get_cfg(self, step: int, total_steps: int, base_cfg: float) -> float:
progress = (step / max(1, total_steps - 1)) * 100
if progress >= self.start_percent:
return 1.0
return base_cfg
class LinearDecayCFGScheduler:
"""Linearly decay CFG from base to target."""
def __init__(self, target_cfg: float = 1.0):
self.target_cfg = target_cfg
def get_cfg(self, step: int, total_steps: int, base_cfg: float) -> float:
progress = step / max(1, total_steps - 1)
return base_cfg + (self.target_cfg - base_cfg) * progress
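
# Illustrative schedule values: with base_cfg=8.0 over 10 steps,
# CFGFreeScheduler(70.0) yields 8.0 for steps 0-6 and 1.0 for steps 7-9,
# while LinearDecayCFGScheduler(1.0) decays evenly from 8.0 to 1.0.
#
#     scheduler = CFGFreeScheduler(start_percent=70.0)
#     cfgs = [scheduler.get_cfg(s, 10, 8.0) for s in range(10)]
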
# ============================================================================
# PIPELINE PROTOCOL
# ============================================================================
@runtime_checkable
class PipelineProtocol(Protocol):
"""Protocol for the main generation pipeline."""
def run(self, ctx: "Context") -> "Context":
"""Execute the full generation pipeline.
Args:
ctx: Configured context with all parameters
Returns:
Context with generated images in current_image
"""
...
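
# A pipeline satisfying PipelineProtocol can be as simple as applying an
# ordered list of processors to the context. Sketch (assumes `self.processors`
# holds BaseProcessor subclasses and `self.model` satisfies ModelProtocol;
# neither attribute is defined by this module):
#
#     class SimplePipeline:
#         def run(self, ctx: "Context") -> "Context":
#             for proc in self.processors:
#                 ctx = proc.run_if_enabled(ctx, self.model)
#             return ctx
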
# ============================================================================
# TYPE ALIASES
# ============================================================================
# Conditioning tuple type (used throughout the codebase)
Conditioning = list[tuple[torch.Tensor, dict[str, Any]]]
# Latent dict type
LatentDict = dict[str, torch.Tensor]
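
# Illustrative Conditioning value (the tensor shape is hypothetical): one
# (token-embedding tensor, metadata dict) pair per encoded prompt.
#
#     cond: Conditioning = [(torch.zeros(1, 77, 768), {})]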