"""High-resolution fix processor for LightDiffusion-Next. This processor upscales latents and runs an additional diffusion pass to enhance detail at higher resolutions. """ import logging import random from typing import TYPE_CHECKING, Any, Optional, Callable import torch if TYPE_CHECKING: from src.Core.PipelineContext import PipelineContext from src.Core.AbstractModel import AbstractModel class HiresFix: """High-resolution fix processor. Upscales latents in latent space and runs additional sampling to enhance details at the higher resolution. """ # Default settings DEFAULT_SCALE = 2.0 DEFAULT_DENOISE = 0.35 DEFAULT_STEPS_RATIO = 0.5 DEFAULT_CFG = 8 @classmethod def apply( cls, latents: dict, ctx: "PipelineContext", model: "AbstractModel", positive: Any, negative: Any, scale: float = None, denoise: float = None, steps: int = None, callback: Optional[Callable] = None, ) -> dict: """Apply high-resolution fix to latents. Args: latents: Dictionary containing 'samples' key with latent tensor ctx: Pipeline context with configuration model: The loaded model instance positive: Positive conditioning negative: Negative conditioning scale: Upscale factor (default: 2.0) denoise: Denoising strength (default: 0.45) steps: Number of sampling steps (default: 50% of original) callback: Optional callback for live previews Returns: Dictionary with upscaled and refined latents """ logger = logging.getLogger(__name__) # Check if model supports hires fix if not model.capabilities.supports_hires_fix: logger.warning("Model does not support HiresFix, returning original latents") return latents # Determine model flags is_flux = getattr(model.capabilities, "is_flux", False) is_flux2 = getattr(model.capabilities, "is_flux2", False) # Use defaults if not specified scale = scale or cls.DEFAULT_SCALE # Use a hires-specific context for hires pass (centralizes defaults) hires_ctx = ctx.with_hires_settings(scale) # Calculate steps - for Flux2 Klein (distilled), we can use fewer steps min_steps = 3 if is_flux2 else 10 steps = steps or max(min_steps, int(hires_ctx.sampling.steps)) # Respect denoise default from hires context unless explicitly overridden denoise = denoise or hires_ctx.sampling.denoise # For Flux models, prefer the user's cfg from the original context (pipeline caps apply elsewhere) if is_flux or is_flux2: hires_cfg = ctx.sampling.cfg else: hires_cfg = hires_ctx.sampling.cfg try: # Import required modules from src.Utilities import upscale as upscale_module from src.sample import sampling from src.hidiffusion import msw_msa_attention # Calculate new dimensions from hires context new_width = int(hires_ctx.generation.width) new_height = int(hires_ctx.generation.height) # Get model-specific downscale factor (e.g., 8 for SD, 16 for Flux) downscale_factor = 8 try: latent_format = model.get_model_object("latent_format") if hasattr(latent_format, "downscale_factor"): downscale_factor = latent_format.downscale_factor elif hasattr(latent_format, "spacial_downscale_ratio"): downscale_factor = latent_format.spacial_downscale_ratio except Exception: pass # Validate against model capabilities new_width, new_height = model.capabilities.validate_resolution(new_width, new_height) logger.info(f"HiresFix: upscaling from {ctx.generation.width}x{ctx.generation.height} to {new_width}x{new_height}") # Upscale latents latent_upscale = upscale_module.LatentUpscale() upscaled = latent_upscale.upscale( samples=latents, width=new_width, height=new_height, downscale_factor=downscale_factor, )[0] # Generate new seed for hires pass (PyTorch max: 2**63 - 1) hires_seed = random.randint(1, 2**63 - 1) # Apply HiDiffusion optimizer only for very high resolutions (>2048px) # This avoids the grid/weave artifacts reported at standard hires sizes if not is_flux and (new_width > 2048 or new_height > 2048): try: hidiff_optimizer = msw_msa_attention.ApplyMSWMSAAttentionSimple() optimized_model = hidiff_optimizer.go(model_type="auto", model=model.model)[0] logger.info("HiresFix: Applied HiDiffusion optimization for extreme resolution") except Exception: optimized_model = model.model else: optimized_model = model.model # Create sampler and run hires pass ksampler = sampling.KSampler() # If model requires resolution-aware conditioning (e.g., SDXL), adjust prompts/conds try: if getattr(model.capabilities, "requires_size_conditioning", False): # Re-encode prompts if raw text was provided def _is_encoded_list(obj): return isinstance(obj, (list, tuple)) and len(obj) > 0 and isinstance(obj[0], (list, tuple)) and isinstance(obj[0][1], dict) if isinstance(positive, (str, list)) and not _is_encoded_list(positive): positive, negative = model.encode_prompt(ctx.prompt, ctx.negative_prompt) # Recursively update width/height in any meta dicts def _update_meta(obj): if isinstance(obj, (list, tuple)): for item in obj: if isinstance(item, (list, tuple)) and len(item) > 1 and isinstance(item[1], dict): item[1].update({ "width": new_width, "height": new_height, "crop_w": 0, "crop_h": 0, "target_width": new_width, "target_height": new_height, }) else: _update_meta(item) _update_meta(positive) _update_meta(negative) except Exception: pass hires_result = ksampler.sample( seed=hires_seed, steps=steps, cfg=hires_cfg, sampler_name=hires_ctx.sampling.sampler, scheduler=hires_ctx.sampling.scheduler, denoise=denoise, model=optimized_model, positive=positive, negative=negative, latent_image=upscaled, pipeline=True, flux=is_flux, flux2=is_flux2, # CRITICAL: Always disable multi-scale for the hires pass itself # Multi-scale downscales during sampling, which defeats the purpose of hires fix # and can introduce blurriness or artifacts. enable_multiscale=False, cfg_free_enabled=hires_ctx.sampling.cfg_free_enabled, cfg_free_start_percent=hires_ctx.sampling.cfg_free_start_percent, batched_cfg=hires_ctx.sampling.batched_cfg, dynamic_cfg_rescaling=hires_ctx.sampling.dynamic_cfg_rescaling, dynamic_cfg_method=hires_ctx.sampling.dynamic_cfg_method, dynamic_cfg_percentile=hires_ctx.sampling.dynamic_cfg_percentile, dynamic_cfg_target_scale=hires_ctx.sampling.dynamic_cfg_target_scale, callback=callback, ) logger.info("HiresFix: completed successfully") return hires_result[0] except Exception as e: logger.exception(f"HiresFix failed: {e}") # Return original latents on failure return latents @classmethod def apply_to_image( cls, image: torch.Tensor, ctx: "PipelineContext", model: "AbstractModel", positive: Any, negative: Any, scale: float = None, callback: Optional[Callable] = None, ) -> torch.Tensor: """Apply high-resolution fix starting from a decoded image. This encodes the image to latents, applies hires fix, then decodes. Args: image: Image tensor in [0, 1] range ctx: Pipeline context model: The loaded model positive: Positive conditioning negative: Negative conditioning scale: Upscale factor callback: Optional callback for live previews Returns: Enhanced image tensor """ logger = logging.getLogger(__name__) try: # Encode image to latents from src.AutoEncoders import VariationalAE vae_encode = VariationalAE.VAEEncode() latents = vae_encode.encode(vae=model.vae, pixels=image)[0] # Apply hires fix enhanced_latents = cls.apply( latents=latents, ctx=ctx, model=model, positive=positive, negative=negative, scale=scale, callback=callback, ) # Decode back to image return model.decode(enhanced_latents["samples"]) except Exception as e: logger.exception(f"HiresFix (image mode) failed: {e}") return image @classmethod def is_enabled(cls, ctx: "PipelineContext") -> bool: """Check if HiresFix should be applied based on context. Args: ctx: Pipeline context Returns: True if HiresFix should be applied """ return ctx.features.hires_fix