Spaces:
Running on Zero
Running on Zero
| """High-resolution fix processor for LightDiffusion-Next. | |
| This processor upscales latents and runs an additional diffusion pass | |
| to enhance detail at higher resolutions. | |
| """ | |
| import logging | |
| import random | |
| from typing import TYPE_CHECKING, Any, Optional, Callable | |
| import torch | |
| if TYPE_CHECKING: | |
| from src.Core.PipelineContext import PipelineContext | |
| from src.Core.AbstractModel import AbstractModel | |
class HiresFix:
    """High-resolution fix processor.

    Upscales latents in latent space and runs additional sampling
    to enhance details at the higher resolution.

    The class is stateless: it carries only configuration constants, and
    every method is a classmethod, so it never needs instantiation.
    """

    # Default settings
    DEFAULT_SCALE = 2.0        # latent upscale factor
    DEFAULT_DENOISE = 0.35     # denoising strength for the hires pass
    DEFAULT_STEPS_RATIO = 0.5  # fraction of the original step count
    DEFAULT_CFG = 8            # classifier-free guidance scale

    @classmethod
    def apply(
        cls,
        latents: dict,
        ctx: "PipelineContext",
        model: "AbstractModel",
        positive: Any,
        negative: Any,
        scale: Optional[float] = None,
        denoise: Optional[float] = None,
        steps: Optional[int] = None,
        callback: Optional[Callable] = None,
    ) -> dict:
        """Apply high-resolution fix to latents.

        Args:
            latents: Dictionary containing 'samples' key with latent tensor
            ctx: Pipeline context with configuration
            model: The loaded model instance
            positive: Positive conditioning
            negative: Negative conditioning
            scale: Upscale factor (default: DEFAULT_SCALE, i.e. 2.0)
            denoise: Denoising strength (default: taken from the hires context)
            steps: Number of sampling steps (default: derived from the hires
                context, floored at a model-dependent minimum)
            callback: Optional callback for live previews

        Returns:
            Dictionary with upscaled and refined latents. The original
            ``latents`` dict is returned unchanged when the model does not
            support hires fix or when the hires pass raises.
        """
        logger = logging.getLogger(__name__)

        # Check if model supports hires fix
        if not model.capabilities.supports_hires_fix:
            logger.warning("Model does not support HiresFix, returning original latents")
            return latents

        # Determine model flags
        is_flux = getattr(model.capabilities, "is_flux", False)
        is_flux2 = getattr(model.capabilities, "is_flux2", False)

        # Use defaults only when the caller passed None — an explicit 0/0.0
        # from the caller must be honored, so no truthiness (`or`) fallbacks.
        if scale is None:
            scale = cls.DEFAULT_SCALE

        # Use a hires-specific context for hires pass (centralizes defaults)
        hires_ctx = ctx.with_hires_settings(scale)

        # Calculate steps - for Flux2 Klein (distilled), we can use fewer steps
        min_steps = 3 if is_flux2 else 10
        if steps is None:
            steps = max(min_steps, int(hires_ctx.sampling.steps))

        # Respect denoise default from hires context unless explicitly overridden
        if denoise is None:
            denoise = hires_ctx.sampling.denoise

        # For Flux models, prefer the user's cfg from the original context
        # (pipeline caps apply elsewhere)
        if is_flux or is_flux2:
            hires_cfg = ctx.sampling.cfg
        else:
            hires_cfg = hires_ctx.sampling.cfg

        try:
            # Import required modules lazily so the class stays importable
            # even when the heavy pipeline modules are unavailable.
            from src.Utilities import upscale as upscale_module
            from src.sample import sampling
            from src.hidiffusion import msw_msa_attention

            # Calculate new dimensions from hires context
            new_width = int(hires_ctx.generation.width)
            new_height = int(hires_ctx.generation.height)

            # Get model-specific downscale factor (e.g., 8 for SD, 16 for Flux)
            downscale_factor = 8
            try:
                latent_format = model.get_model_object("latent_format")
                if hasattr(latent_format, "downscale_factor"):
                    downscale_factor = latent_format.downscale_factor
                elif hasattr(latent_format, "spacial_downscale_ratio"):
                    downscale_factor = latent_format.spacial_downscale_ratio
            except Exception:
                # Best-effort probe; fall back to the SD default of 8.
                pass

            # Validate against model capabilities
            new_width, new_height = model.capabilities.validate_resolution(new_width, new_height)

            logger.info(f"HiresFix: upscaling from {ctx.generation.width}x{ctx.generation.height} to {new_width}x{new_height}")

            # Upscale latents
            latent_upscale = upscale_module.LatentUpscale()
            upscaled = latent_upscale.upscale(
                samples=latents,
                width=new_width,
                height=new_height,
                downscale_factor=downscale_factor,
            )[0]

            # Generate new seed for hires pass (PyTorch max: 2**63 - 1)
            hires_seed = random.randint(1, 2**63 - 1)

            # Apply HiDiffusion optimizer only for very high resolutions (>2048px)
            # This avoids the grid/weave artifacts reported at standard hires sizes
            if not is_flux and (new_width > 2048 or new_height > 2048):
                try:
                    hidiff_optimizer = msw_msa_attention.ApplyMSWMSAAttentionSimple()
                    optimized_model = hidiff_optimizer.go(model_type="auto", model=model.model)[0]
                    logger.info("HiresFix: Applied HiDiffusion optimization for extreme resolution")
                except Exception:
                    # Optimization is optional; fall back to the unmodified model.
                    optimized_model = model.model
            else:
                optimized_model = model.model

            # Create sampler and run hires pass
            ksampler = sampling.KSampler()

            # If model requires resolution-aware conditioning (e.g., SDXL), adjust prompts/conds
            try:
                if getattr(model.capabilities, "requires_size_conditioning", False):
                    # Re-encode prompts if raw text was provided
                    def _is_encoded_list(obj):
                        # Encoded conditioning looks like [[tensor, {meta}], ...].
                        return isinstance(obj, (list, tuple)) and len(obj) > 0 and isinstance(obj[0], (list, tuple)) and isinstance(obj[0][1], dict)

                    if isinstance(positive, (str, list)) and not _is_encoded_list(positive):
                        positive, negative = model.encode_prompt(ctx.prompt, ctx.negative_prompt)

                    # Recursively update width/height in any meta dicts so the
                    # size conditioning matches the new hires resolution.
                    def _update_meta(obj):
                        if isinstance(obj, (list, tuple)):
                            for item in obj:
                                if isinstance(item, (list, tuple)) and len(item) > 1 and isinstance(item[1], dict):
                                    item[1].update({
                                        "width": new_width,
                                        "height": new_height,
                                        "crop_w": 0,
                                        "crop_h": 0,
                                        "target_width": new_width,
                                        "target_height": new_height,
                                    })
                                else:
                                    _update_meta(item)

                    _update_meta(positive)
                    _update_meta(negative)
            except Exception:
                # Size-conditioning adjustment is best-effort; sampling can
                # proceed with the original conditioning.
                pass

            hires_result = ksampler.sample(
                seed=hires_seed,
                steps=steps,
                cfg=hires_cfg,
                sampler_name=hires_ctx.sampling.sampler,
                scheduler=hires_ctx.sampling.scheduler,
                denoise=denoise,
                model=optimized_model,
                positive=positive,
                negative=negative,
                latent_image=upscaled,
                pipeline=True,
                flux=is_flux,
                flux2=is_flux2,
                # CRITICAL: Always disable multi-scale for the hires pass itself
                # Multi-scale downscales during sampling, which defeats the purpose of hires fix
                # and can introduce blurriness or artifacts.
                enable_multiscale=False,
                cfg_free_enabled=hires_ctx.sampling.cfg_free_enabled,
                cfg_free_start_percent=hires_ctx.sampling.cfg_free_start_percent,
                batched_cfg=hires_ctx.sampling.batched_cfg,
                dynamic_cfg_rescaling=hires_ctx.sampling.dynamic_cfg_rescaling,
                dynamic_cfg_method=hires_ctx.sampling.dynamic_cfg_method,
                dynamic_cfg_percentile=hires_ctx.sampling.dynamic_cfg_percentile,
                dynamic_cfg_target_scale=hires_ctx.sampling.dynamic_cfg_target_scale,
                callback=callback,
            )

            logger.info("HiresFix: completed successfully")
            return hires_result[0]

        except Exception as e:
            logger.exception(f"HiresFix failed: {e}")
            # Return original latents on failure
            return latents

    @classmethod
    def apply_to_image(
        cls,
        image: torch.Tensor,
        ctx: "PipelineContext",
        model: "AbstractModel",
        positive: Any,
        negative: Any,
        scale: Optional[float] = None,
        callback: Optional[Callable] = None,
    ) -> torch.Tensor:
        """Apply high-resolution fix starting from a decoded image.

        This encodes the image to latents, applies hires fix, then decodes.

        Args:
            image: Image tensor in [0, 1] range
            ctx: Pipeline context
            model: The loaded model
            positive: Positive conditioning
            negative: Negative conditioning
            scale: Upscale factor
            callback: Optional callback for live previews

        Returns:
            Enhanced image tensor, or the input ``image`` unchanged if any
            stage (encode, hires pass, decode) raises.
        """
        logger = logging.getLogger(__name__)
        try:
            # Encode image to latents
            from src.AutoEncoders import VariationalAE

            vae_encode = VariationalAE.VAEEncode()
            latents = vae_encode.encode(vae=model.vae, pixels=image)[0]

            # Apply hires fix
            enhanced_latents = cls.apply(
                latents=latents,
                ctx=ctx,
                model=model,
                positive=positive,
                negative=negative,
                scale=scale,
                callback=callback,
            )

            # Decode back to image
            return model.decode(enhanced_latents["samples"])
        except Exception as e:
            logger.exception(f"HiresFix (image mode) failed: {e}")
            return image

    @classmethod
    def is_enabled(cls, ctx: "PipelineContext") -> bool:
        """Check if HiresFix should be applied based on context.

        Args:
            ctx: Pipeline context

        Returns:
            True if HiresFix should be applied
        """
        return ctx.features.hires_fix