Spaces:
Running on Zero
Running on Zero
| """SD1.5 model adapter for LightDiffusion-Next. | |
| Provides a clean interface to the SD1.5 model that inherits from | |
| AbstractModel and wraps the existing infrastructure. | |
| """ | |
| import logging | |
| from typing import TYPE_CHECKING, Any, Callable, Optional | |
| import torch | |
| from src.Core.AbstractModel import AbstractModel, ModelCapabilities | |
| if TYPE_CHECKING: | |
| from src.Core.Context import Context | |
class SD15Model(AbstractModel):
    """SD1.5 model implementation.

    Wraps the existing SD15 model loading and inference code
    with the clean AbstractModel interface.
    """

    # Checkpoint used by load() when neither the constructor nor the call
    # supplies a path.
    _DEFAULT_CHECKPOINT = "./include/checkpoints/DreamShaper_8_pruned.safetensors"

    def __init__(self, model_path: str | None = None):
        """Initialize the SD15 model adapter.

        Args:
            model_path: Path to the model checkpoint (safetensors/pt)
        """
        super().__init__(model_path)
        # Default CLIP-skip: stop at the penultimate CLIP layer, the usual
        # setting for SD1.5 community checkpoints.
        self._clip_skip = -2

    def _create_capabilities(self) -> ModelCapabilities:
        """Create capabilities for SD1.5 models."""
        return ModelCapabilities(
            min_resolution=256,
            max_resolution=2048,
            preferred_resolution=512,
            requires_resolution_multiple=64,
            supports_hires_fix=True,
            supports_img2img=True,
            supports_inpainting=True,
            supports_controlnet=True,
            supports_stable_fast=True,
            supports_deepcache=True,
            supports_tome=True,
            uses_dual_clip=False,
            requires_size_conditioning=False,
        )

    def load(self, model_path: str | None = None) -> "SD15Model":
        """Load the SD1.5 model from disk.

        Args:
            model_path: Optional override for the model path. Falls back to
                the path given at construction time, then to the default
                checkpoint.

        Returns:
            Self for method chaining

        Raises:
            Exception: Re-raises whatever the checkpoint loader raised.
        """
        logger = logging.getLogger(__name__)
        path = model_path or self.model_path
        if path is None:
            # Use default checkpoint
            path = self._DEFAULT_CHECKPOINT
        # Guard: Don't reload if already loaded with same path
        if self._loaded and self.model_path == path:
            logger.info("SD15Model: Already loaded %s, skipping redundant load", path)
            return self
        self.model_path = path
        try:
            from src.FileManaging import Loader

            loader = Loader.CheckpointLoaderSimple()
            # Loader returns (unet, clip, vae) at indices 0..2; index rather
            # than tuple-unpack in case the result carries extra trailing items.
            result = loader.load_checkpoint(ckpt_name=path)
            self.model = result[0]
            self.clip = result[1]
            self.vae = result[2]
            self._loaded = True
            logger.info("SD15Model: loaded %s", path)
        except Exception as e:
            logger.exception("SD15Model: failed to load %s: %s", path, e)
            raise
        return self

    def get_model_object(self, name: str) -> Any:
        """Get an attribute from the underlying model.

        Args:
            name: Attribute name to resolve on the wrapped model.

        Returns:
            The resolved object, or None when no model is loaded.
        """
        # Identity check, not truthiness: a loaded model object could be
        # falsy (e.g. if it defines __len__/__bool__) yet still be valid.
        if self.model is not None:
            return self.model.get_model_object(name)
        return None

    def encode_prompt(
        self,
        prompt: str | list[str],
        negative_prompt: str | list[str] = "",
        clip_skip: int | None = None,
    ) -> tuple[Any, Any]:
        """Encode text prompts into conditioning tensors.

        Args:
            prompt: Positive prompt(s) to encode
            negative_prompt: Negative prompt(s) to encode
            clip_skip: Number of CLIP layers to skip (default: -2)

        Returns:
            Tuple of (positive_conditioning, negative_conditioning)

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        if not self._loaded:
            raise RuntimeError("Model must be loaded before encoding prompts")
        clip_skip = clip_skip if clip_skip is not None else self._clip_skip
        try:
            from src.clip import Clip

            # Apply CLIP skip
            clip_layer = Clip.CLIPSetLastLayer()
            processed_clip = clip_layer.set_last_layer(
                stop_at_clip_layer=clip_skip,
                clip=self.clip,
            )[0]
            # Encode prompts with the same (clip-skipped) encoder so positive
            # and negative conditionings come from identical CLIP state.
            encoder = Clip.CLIPTextEncode()
            positive = encoder.encode(
                text=prompt,
                clip=processed_clip,
            )[0]
            negative = encoder.encode(
                text=negative_prompt,
                clip=processed_clip,
            )[0]
            return positive, negative
        except Exception as e:
            logging.getLogger(__name__).exception(f"Prompt encoding failed: {e}")
            raise

    def generate(
        self,
        ctx: "Context",
        positive: Any,
        negative: Any,
        latent_image: Optional[Any] = None,
        start_step: Optional[int] = None,
        last_step: Optional[int] = None,
        disable_noise: bool = False,
        callback: Optional[Callable] = None,
    ) -> dict:
        """Generate latents using the sampler.

        Args:
            ctx: Pipeline context with generation parameters
            positive: Positive conditioning
            negative: Negative conditioning
            latent_image: Optional existing latent to continue from
            start_step: Optional step to start sampling from
            last_step: Optional step to stop sampling at
            disable_noise: If True, no fresh noise is added before sampling
            callback: Per-step callback; falls back to ctx.callback when None

        Returns:
            Dictionary with 'samples' key containing generated latents

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        if not self._loaded:
            raise RuntimeError("Model must be loaded before generating")
        try:
            from src.sample import sampling
            from src.Utilities import Latent
            from src.hidiffusion import msw_msa_attention

            # Use provided latent or create empty one
            if latent_image is not None:
                latent = latent_image
            else:
                # Create empty latent
                latent_gen = Latent.EmptyLatentImage()
                latent = latent_gen.generate(
                    width=ctx.generation.width,
                    height=ctx.generation.height,
                    batch_size=ctx.generation.batch,
                )[0]
            # Add seeds for deterministic noise.
            # NOTE(review): applied to caller-provided latents as well,
            # overwriting any seeds already attached — confirm this is intended
            # for the img2img/continuation path.
            latent["seeds"] = (
                ctx.seeds[: ctx.generation.batch] if ctx.seeds else [ctx.seed]
            )
            # Apply HiDiffusion optimization only for very high resolutions.
            # NOTE(review): declared max_resolution is 2048, so this strict
            # ">" comparison never fires for capability-respecting callers —
            # confirm whether ">=" (or a lower threshold) was intended.
            if ctx.generation.width > 2048 or ctx.generation.height > 2048:
                try:
                    # Clone model before patching so the loaded model stays
                    # unmodified for subsequent generations.
                    patch_model = self.model.clone()
                    hidiff = msw_msa_attention.ApplyMSWMSAAttentionSimple()
                    optimized_model = hidiff.go(model_type="sd15", model=patch_model)[0]
                except Exception:
                    # Best-effort optimization: fall back to the unpatched model.
                    optimized_model = self.model
            else:
                optimized_model = self.model
            # Run sampling
            ksampler = sampling.KSampler()
            result = ksampler.sample(
                seed=ctx.seed,
                steps=ctx.sampling.steps,
                cfg=ctx.sampling.cfg,
                sampler_name=ctx.sampling.sampler,
                scheduler=ctx.sampling.scheduler,
                denoise=ctx.sampling.denoise,
                pipeline=True,
                model=optimized_model,
                positive=positive,
                negative=negative,
                latent_image=latent,
                start_step=start_step,
                last_step=last_step,
                disable_noise=disable_noise,
                callback=callback or ctx.callback,
                enable_multiscale=ctx.sampling.enable_multiscale,
                multiscale_factor=ctx.sampling.multiscale_factor,
                multiscale_fullres_start=ctx.sampling.multiscale_fullres_start,
                multiscale_fullres_end=ctx.sampling.multiscale_fullres_end,
                multiscale_intermittent_fullres=ctx.sampling.multiscale_intermittent_fullres,
                cfg_free_enabled=ctx.sampling.cfg_free_enabled,
                cfg_free_start_percent=ctx.sampling.cfg_free_start_percent,
                batched_cfg=ctx.sampling.batched_cfg,
                dynamic_cfg_rescaling=ctx.sampling.dynamic_cfg_rescaling,
                dynamic_cfg_method=ctx.sampling.dynamic_cfg_method,
                dynamic_cfg_percentile=ctx.sampling.dynamic_cfg_percentile,
                dynamic_cfg_target_scale=ctx.sampling.dynamic_cfg_target_scale,
                adaptive_noise_enabled=ctx.sampling.adaptive_noise_enabled,
                adaptive_noise_method=ctx.sampling.adaptive_noise_method,
            )
            return result[0]
        except Exception as e:
            logging.getLogger(__name__).exception(f"Generation failed: {e}")
            raise

    def decode(self, latents: "torch.Tensor | dict") -> torch.Tensor:
        """Decode latents to pixel space.

        Args:
            latents: Latent tensor or dict with 'samples' key

        Returns:
            Decoded image tensor in [0, 1] range

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        if not self._loaded:
            raise RuntimeError("Model must be loaded before decoding")
        try:
            from src.AutoEncoders import VariationalAE

            decoder = VariationalAE.VAEDecode()
            # Handle both raw tensor and dict input
            if isinstance(latents, dict):
                samples = latents
            else:
                samples = {"samples": latents}
            result = decoder.decode(
                samples=samples,
                vae=self.vae,
            )
            return result[0]
        except Exception as e:
            logging.getLogger(__name__).exception(f"Decoding failed: {e}")
            raise