import torch
import torch.nn.functional as F
import math
import numpy as np  # NOTE(review): appears unused in this file; kept for compatibility
from collections import OrderedDict

# =======================================================================
# vUltimate - REAL Deep Code Audit
# Based on v13 (THE_last_version) with CRITICAL FIXES
# =======================================================================

# =======================================================================
# 1. SMART CACHING (Speed optimization ~15-20%)
# =======================================================================

class SmartMaskCache:
    """LRU cache for blend masks to avoid regeneration.

    Backed by an OrderedDict: entries are moved to the end on access
    (most-recently-used) and the front entry (least-recently-used) is
    evicted once the cache grows past ``max_size``.
    """

    def __init__(self, max_size=50):
        self.cache = OrderedDict()
        self.max_size = max_size

    def get(self, key):
        """Return the cached value for ``key`` (marking it recently used),
        or None on a miss."""
        if key in self.cache:
            self.cache.move_to_end(key)
            return self.cache[key]
        return None

    def set(self, key, value):
        """Insert/refresh ``key`` and evict the least-recently-used entry
        if the cache exceeds ``max_size``."""
        if key in self.cache:
            self.cache.move_to_end(key)
        self.cache[key] = value
        if len(self.cache) > self.max_size:
            self.cache.popitem(last=False)


# Global cache instance shared by compute_advanced_blend_padding.
_MASK_CACHE = SmartMaskCache()

# =======================================================================
# 2. SAFE EPSILON (CRITICAL FIX: infinite recursion bug in v13!)
# =======================================================================

def get_safe_epsilon(tensor_or_dtype):
    """
    Float16-safe epsilon - CRITICAL for half precision!

    vUltimate Fix: v13 had an INFINITE RECURSION bug here (the float32
    branch returned ``get_safe_epsilon(torch.float16)`` instead of a
    literal value). This version returns plain constants only.

    Args:
        tensor_or_dtype: torch.Tensor or torch.dtype

    Returns:
        float: safe epsilon for the given dtype
    """
    if isinstance(tensor_or_dtype, torch.Tensor):
        dtype = tensor_or_dtype.dtype
    else:
        dtype = tensor_or_dtype

    # Float16 minimum normal is ~6e-5, so 1e-6 would underflow to zero.
    if dtype in (torch.float16, torch.bfloat16):
        return 1e-3  # Safe for half precision
    elif dtype == torch.float32:
        return 1e-6  # FIX: direct value instead of the v13 recursive call
    else:
        return 1e-12  # High precision for float64

# =======================================================================
# 3. LATENT COLOR FIX (Variance-preserving blend)
# =======================================================================

def blend_with_variance_fix(a, b, mask):
    """
    Mathematically correct blending of latent noise.

    FLOAT16 FIX: safe sqrt with a dtype-adaptive epsilon.

    Args:
        a: Primary layer (active where mask=1) -> Advanced/Circular
        b: Background layer (active where mask=0) -> Simple/Replicate
        mask: Blend mask [0, 1]

    IMPORTANT: from v11+ the mask semantics are:
        mask=1.0 on EDGES (where Advanced padding is needed)
        mask=0.0 in CENTER (where content or Simple padding is)

    Returns:
        Blended tensor, rescaled so unit-variance inputs keep unit variance.
    """
    # 1. Linear blend
    blended = a * mask + b * (1 - mask)

    # 2. Variance correction: for independent noise, Var = m^2 + (1-m)^2,
    #    so dividing by its sqrt restores unit variance at every pixel.
    eps_val = get_safe_epsilon(mask.dtype)
    variance_fix = torch.sqrt(mask**2 + (1 - mask)**2 + eps_val)
    return blended / variance_fix

# =======================================================================
# 4. LEGACY FADE TO BLACK (For ZOOM effect)
# =======================================================================

def compute_blend_fade_to_black(padded, pad_h, pad_w, fade_strength=0.1):
    """
    LEGACY MODE for Zoom effect (V3.5 logic from v11-v13).

    The gradient covers the ENTIRE padding plus part of the content,
    producing a vignette from 0 (edge) to 1 (center).

    vUltimate Note: this is v13 logic (NOT v7 logic). v7 applied fade only
    to content inside padding zones; v13 fades the ENTIRE image including
    padding.

    FIX (this revision): when the fade zone exceeds half the image size,
    the old top/bottom (left/right) slice multiplications overlapped and
    the overlap was darkened TWICE. We now build a single per-axis weight
    vector with ``torch.minimum`` — identical results when the zones do
    not overlap, correct single darkening when they do.

    Args:
        padded: Already padded tensor [B, C, H, W]
        pad_h: Vertical padding size
        pad_w: Horizontal padding size
        fade_strength: Fade depth into content (0.0-1.0), typically 0.05-0.2

    Returns:
        Tensor with darkened edges
    """
    b, c, H, W = padded.shape

    # Content size (without padding); clamped so negative sizes cannot occur.
    h_content = max(H - 2 * pad_h, 0)
    w_content = max(W - 2 * pad_w, 0)

    # Fade depth inside the content region
    blend_in_h = int(h_content * fade_strength)
    blend_in_w = int(w_content * fade_strength)

    # Total fade zone = padding + entry into content
    total_fade_h = pad_h + blend_in_h
    total_fade_w = pad_w + blend_in_w

    result = padded.clone()

    # ═══════════════════════════════════════════════════════════════════
    # VERTICAL EDGES
    # ═══════════════════════════════════════════════════════════════════
    if total_fade_h > 0:
        # 0 -> 1 gradient from the edge toward the center
        fade = torch.linspace(0, 1, steps=total_fade_h,
                              device=padded.device, dtype=padded.dtype)
        safe_h = min(total_fade_h, H)
        weight = torch.ones(H, device=padded.device, dtype=padded.dtype)
        # Top edge: ascending gradient; bottom edge: flipped gradient.
        # torch.minimum keeps each row's single darkest factor even when
        # the two zones overlap (no double multiplication).
        weight[:safe_h] = torch.minimum(weight[:safe_h], fade[:safe_h])
        weight[-safe_h:] = torch.minimum(weight[-safe_h:], fade[:safe_h].flip(0))
        result *= weight.view(1, 1, -1, 1)

    # ═══════════════════════════════════════════════════════════════════
    # HORIZONTAL EDGES
    # ═══════════════════════════════════════════════════════════════════
    if total_fade_w > 0:
        fade = torch.linspace(0, 1, steps=total_fade_w,
                              device=padded.device, dtype=padded.dtype)
        safe_w = min(total_fade_w, W)
        weight = torch.ones(W, device=padded.device, dtype=padded.dtype)
        weight[:safe_w] = torch.minimum(weight[:safe_w], fade[:safe_w])
        weight[-safe_w:] = torch.minimum(weight[-safe_w:], fade[:safe_w].flip(0))
        result *= weight.view(1, 1, 1, -1)

    return result

# =======================================================================
# 5. ADVANCED BLEND MASK (For modern tiling mode)
# =======================================================================
def create_advanced_blend_mask(h, w, blend_width, device, dtype=torch.float32,
                               falloff_curve="smoothstep", edge_sharpness=1.0):
    """
    Creates an edge blend mask (cached by the caller).

    MASK SEMANTICS (v11+ convention):
        1.0 = on the very EDGE (where Advanced Padding is needed)
        0.0 = in CENTER (where content or Simple Padding is)

    Args:
        h, w: Mask dimensions
        blend_width: Transition zone width (pixels)
        device: Torch device
        dtype: Data type
        falloff_curve: Curve type ('linear', 'smoothstep', 'cosine')
        edge_sharpness: Edge sharpness (1.0 = normal, >1 = sharper, <1 = softer)

    Returns:
        Mask of size [1, 1, h, w]
    """
    if blend_width <= 0:
        return torch.ones(1, 1, h, w, device=device, dtype=dtype)

    # BUG FIX 6a: normalise falloff_curve to a known value; warn loudly if
    # the UI has sent something the backend doesn't actually implement.
    _KNOWN_FALLOFFS = {'linear', 'smoothstep', 'cosine'}
    if falloff_curve not in _KNOWN_FALLOFFS:
        print(f"[AdvancedBlend] Warning: unsupported falloff_curve '{falloff_curve}' "
              f"- falling back to 'smoothstep'. Supported: {sorted(_KNOWN_FALLOFFS)}")
        falloff_curve = 'smoothstep'

    # The ramp can never exceed half the image, otherwise the two opposite
    # edges would overlap in the middle.
    blend_w = min(blend_width, w // 2)
    blend_h = min(blend_width, h // 2)

    mask = torch.zeros((1, 1, h, w), device=device, dtype=dtype)

    def get_ramp(size):
        """Generate a 0 -> 1 gradient with the configured curve.

        BUG FIX: linspace(0, 1, 1) used to yield [0], i.e. a zero mask for
        a one-pixel ramp. For size <= 1 we now return ones so the boundary
        pixel keeps full weight.
        """
        if size <= 1:
            return torch.ones(max(size, 1), device=device, dtype=dtype)
        t = torch.linspace(0, 1, steps=size, device=device, dtype=dtype)
        if edge_sharpness != 1.0:
            t = torch.pow(t, edge_sharpness)
        if falloff_curve == 'smoothstep':
            return t * t * (3 - 2 * t)
        elif falloff_curve == 'cosine':
            return (1 - torch.cos(t * math.pi)) / 2
        elif falloff_curve == 'linear':
            return t
        return t

    # Fill edges; torch.maximum merges overlapping corner regions.
    if blend_w > 0:
        ramp = get_ramp(blend_w)
        # Left edge (flipped: 1.0 at column 0)
        mask[:, :, :, :blend_w] = torch.maximum(mask[:, :, :, :blend_w],
                                                ramp.flip(0).view(1, 1, 1, -1))
        # Right edge (1.0 at the last column)
        mask[:, :, :, -blend_w:] = torch.maximum(mask[:, :, :, -blend_w:],
                                                 ramp.view(1, 1, 1, -1))

    if blend_h > 0:
        ramp = get_ramp(blend_h)
        # Top edge
        mask[:, :, :blend_h, :] = torch.maximum(mask[:, :, :blend_h, :],
                                                ramp.flip(0).view(1, 1, -1, 1))
        # Bottom edge
        mask[:, :, -blend_h:, :] = torch.maximum(mask[:, :, -blend_h:, :],
                                                 ramp.view(1, 1, -1, 1))

    return mask

# =======================================================================
# 6. IMPROVED BLEND PADDING (Main tiling function)
# =======================================================================

def _pad_simple(input_tensor, pad_h, pad_w, mode_simple):
    """Apply the 'simple' padding mode ('constant' pads with zeros)."""
    if mode_simple == 'constant':
        return F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h),
                     mode='constant', value=0)
    return F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode=mode_simple)


def _pad_advanced(input_tensor, pad_h, pad_w, mode_advanced):
    """Apply the 'advanced' padding mode.

    Handles two special cases that were previously duplicated four times:
      * mode_advanced may be a pre-computed padded tensor - returned as-is;
      * PyTorch 'reflect' padding requires pad < input size, so we fall
        back to 'replicate' when the input is too small (safe reflect).
    """
    if not isinstance(mode_advanced, str):
        return mode_advanced  # Pre-computed tensor
    if mode_advanced == 'reflect':
        b, c, h, w = input_tensor.shape
        if pad_w < w and pad_h < h:
            return F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h),
                         mode='reflect')
        return F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h),
                     mode='replicate')
    # e.g. 'circular'
    return F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode=mode_advanced)


def compute_advanced_blend_padding(input_tensor, pad_h, pad_w,
                                   mode_simple='replicate',
                                   mode_advanced='circular',
                                   blend_strength=0.5,
                                   blend_width=None,
                                   falloff_curve='smoothstep',
                                   edge_sharpness=1.0,
                                   fade_to_black=False,
                                   fade_strength=0.1):
    """
    IMPROVED PADDING MODE.

    Two operation modes:

    1. FADE TO BLACK (fade_to_black=True) - for Zoom effect:
       - Applies one padding (mode_advanced)
       - DARKENS edges, creating a zoom-out effect
       - Uses the legacy compute_blend_fade_to_black function

    2. BLEND TWO PADDINGS (fade_to_black=False) - for quality edges:
       - Creates two different paddings (simple and advanced)
       - Blends them via a mask
       - Applies a variance fix for color correction
       - Does NOT create a zoom effect

    Args:
        input_tensor: Original tensor WITHOUT padding [B, C, H, W]
        pad_h, pad_w: Padding sizes
        mode_simple: Mode for "simple" padding ('replicate', 'constant')
        mode_advanced: Mode for "advanced" padding ('circular', 'reflect'),
            or a pre-computed padded tensor
        blend_strength: Blend strength (0.0-1.0)
        blend_width: Transition width (None = auto: max(pad_h, pad_w))
        falloff_curve: Gradient curve type
        edge_sharpness: Edge sharpness
        fade_to_black: If True, uses the legacy darkening mode
        fade_strength: Darkening strength for fade_to_black mode

    Returns:
        Padded tensor [B, C, H+2*pad_h, W+2*pad_w]
    """
    # ═══════════════════════════════════════════════════════════════════
    # MODE 1: FADE TO BLACK (for Zoom)
    # ═══════════════════════════════════════════════════════════════════
    if fade_to_black:
        padded = _pad_advanced(input_tensor, pad_h, pad_w, mode_advanced)
        return compute_blend_fade_to_black(padded, pad_h, pad_w, fade_strength)

    # ═══════════════════════════════════════════════════════════════════
    # MODE 2: BLEND TWO PADDINGS (Tiling)
    # ═══════════════════════════════════════════════════════════════════

    # Effectively disabled -> simple padding only
    if blend_strength <= 0.001:
        return _pad_simple(input_tensor, pad_h, pad_w, mode_simple)

    # Full strength -> advanced padding only
    if blend_strength >= 0.999:
        return _pad_advanced(input_tensor, pad_h, pad_w, mode_advanced)

    # 1. Prepare both layers
    simple = _pad_simple(input_tensor, pad_h, pad_w, mode_simple)
    advanced = _pad_advanced(input_tensor, pad_h, pad_w, mode_advanced)

    # 2. Get mask from cache
    if blend_width is None:
        blend_width = max(pad_h, pad_w)

    b, c, h, w = input_tensor.shape
    device = input_tensor.device
    dtype = input_tensor.dtype

    # Cache key includes dtype (v11+ feature) so float16/float32 masks
    # never collide.
    cache_key = (h, w, pad_h, pad_w, blend_width, falloff_curve,
                 edge_sharpness, str(device), str(dtype))
    mask = _MASK_CACHE.get(cache_key)
    if mask is None:
        H_pad, W_pad = simple.shape[2:]
        mask = create_advanced_blend_mask(H_pad, W_pad, blend_width, device,
                                          dtype, falloff_curve, edge_sharpness)
        _MASK_CACHE.set(cache_key, mask)

    # 3. Blend with variance fix.
    # v11+ semantics: advanced where mask=1 (edges), simple where mask=0.
    final_mask = mask * blend_strength
    return blend_with_variance_fix(advanced, simple, final_mask)

# =======================================================================
# 7. MULTI-RESOLUTION (Temporal strategy)
# =======================================================================

class BlendStrategy:
    """Interpolation strategies for multi-resolution transitions."""
    LINEAR = "linear"
    COSINE = "cosine"
    EXPONENTIAL = "exponential"
    SIGMOID = "sigmoid"


class MultiResStrategy:
    """Handles temporal blending curves for progressive detail addition."""

    def __init__(self, strategy_type=BlendStrategy.COSINE):
        self.strategy_type = strategy_type

    def get_factor(self, progress, sharpness=1.0):
        """Calculate the blend factor for ``progress`` in [0.0, 1.0].

        ``sharpness`` != 1.0 pre-warps progress with a power curve before
        the strategy curve is applied.
        """
        t = max(0.0, min(1.0, progress))
        if sharpness != 1.0:
            t = math.pow(t, sharpness)

        if self.strategy_type == BlendStrategy.LINEAR:
            return t
        elif self.strategy_type == BlendStrategy.COSINE:
            return (1.0 - math.cos(t * math.pi)) / 2.0
        elif self.strategy_type == BlendStrategy.EXPONENTIAL:
            return math.pow(t, 2)
        elif self.strategy_type == BlendStrategy.SIGMOID:
            # Clamp the ends exactly; the logistic never reaches 0/1.
            if t <= 0:
                return 0.0
            if t >= 1:
                return 1.0
            return 1.0 / (1.0 + math.exp(-12.0 * (t - 0.5)))
        return t


def apply_multires_blend(tensor_simple, tensor_advanced, current_step,
                         start_step, end_step, strategy="cosine",
                         transition_start=0.0, transition_end=0.3,
                         sharpness=1.0, enabled=False):
    """
    Progressive blending from simple to advanced over denoising steps.

    Args:
        tensor_simple: Low-detail padding result
        tensor_advanced: High-detail padding result
        current_step: Current denoising step
        start_step, end_step: Denoising range
        strategy: Interpolation curve type (string or BlendStrategy value)
        transition_start, transition_end: Transition window (0.0-1.0)
        sharpness: Curve adjustment
        enabled: Master switch; when False, tensor_advanced is returned

    Returns:
        Blended tensor
    """
    if not enabled:
        return tensor_advanced

    # BUG FIX 6b: normalise strategy to a supported value before use.
    _KNOWN_STRATEGIES = {BlendStrategy.LINEAR, BlendStrategy.COSINE,
                         BlendStrategy.EXPONENTIAL, BlendStrategy.SIGMOID}
    _STRATEGY_ALIASES = {
        'linear': BlendStrategy.LINEAR,
        'cosine': BlendStrategy.COSINE,
        'exponential': BlendStrategy.EXPONENTIAL,
        'sigmoid': BlendStrategy.SIGMOID,
    }
    if isinstance(strategy, str):
        strategy_key = strategy.lower()
        if strategy_key not in _STRATEGY_ALIASES:
            print(f"[MultiRes] Warning: unsupported strategy '{strategy}' "
                  f"- falling back to 'cosine'. "
                  f"Supported: {sorted(_STRATEGY_ALIASES.keys())}")
            strategy = BlendStrategy.COSINE
        else:
            strategy = _STRATEGY_ALIASES[strategy_key]
    elif strategy not in _KNOWN_STRATEGIES:
        print(f"[MultiRes] Warning: unknown strategy {strategy!r} - falling back to cosine")
        strategy = BlendStrategy.COSINE

    total_steps = end_step - start_step
    if total_steps <= 0:
        return tensor_advanced

    # Normalised position within the denoising range, clamped to [0, 1].
    step_frac = (current_step - start_step) / total_steps
    step_frac = max(0.0, min(1.0, step_frac))

    # Map the global position into the local transition window.
    if step_frac < transition_start:
        local_progress = 0.0
    elif step_frac > transition_end:
        local_progress = 1.0
    else:
        duration = transition_end - transition_start
        if duration <= 0:
            local_progress = 1.0
        else:
            local_progress = (step_frac - transition_start) / duration

    strat = MultiResStrategy(strategy)
    alpha = strat.get_factor(local_progress, sharpness)

    # Short-circuit the extremes to avoid a pointless lerp.
    if alpha <= 0.001:
        return tensor_simple
    if alpha >= 0.999:
        return tensor_advanced

    # Standard lerp (temporal blend, not spatial)
    return tensor_simple * (1.0 - alpha) + tensor_advanced * alpha

# =======================================================================
# 8. HELPER FUNCTIONS (From v13 for compatibility)
# =======================================================================

def create_circular_mask(h, w, center_x=0.5, center_y=0.5, radius=0.5,
                         device='cpu', dtype=torch.float32):
    """
    Creates a circular mask (white circle on a black background).

    FLOAT16 FIX: safe sqrt via dtype-adaptive epsilon.

    NOTE: This is the v13 version (radial distance mask). Different from
    v1/exp which don't have this function.

    Args:
        h, w: Mask dimensions
        center_x, center_y: Circle center in normalised [0, 1] coordinates
        radius: Circle radius in normalised units
        device, dtype: Output placement

    Returns:
        Mask of size [1, 1, h, w] with a 0.2-wide soft edge
    """
    eps_val = get_safe_epsilon(dtype)

    # Coordinate grid spanning [-1, 1] in both axes
    y, x = torch.meshgrid(
        torch.linspace(-1, 1, h, device=device, dtype=dtype),
        torch.linspace(-1, 1, w, device=device, dtype=dtype),
        indexing='ij'
    )

    # Shift center (the 0.5 offset maps [0,1] UI coords onto [-1,1])
    x = x - (center_x - 0.5) * 2
    y = y - (center_y - 0.5) * 2

    # Distance from center; epsilon keeps sqrt's gradient finite at 0
    dist = torch.sqrt(x * x + y * y + eps_val)

    # Soft mask: full inside radius-0.1, fading to 0 over a 0.2 band
    mask = 1.0 - torch.clamp((dist - (radius - 0.1)) / 0.2, 0, 1)

    # Add channel and batch dimensions
    if len(mask.shape) == 2:
        mask = mask.unsqueeze(0).unsqueeze(0)
    return mask


def create_fade_to_black_mask(h, w, strength=0.1, device='cpu', dtype=torch.float32):
    """
    Creates a vignette (darkening towards edges).

    FLOAT16 FIX: safe sqrt via dtype-adaptive epsilon.

    NOTE: This is the v13 version (radial vignette). Different from v1/exp
    which don't have this function.

    FIX (this revision): strength == 0 used to divide by zero and produce
    inf/NaN values; it now returns an all-ones mask (no vignette).

    Args:
        h, w: Mask dimensions
        strength: Vignette depth (0 disables the effect)
        device, dtype: Output placement

    Returns:
        Mask of size [1, 1, h, w]: 1.0 in the center, 0.0 at the corners
    """
    # Guard: strength <= 0 means "no vignette" rather than a 0/0 division.
    if strength <= 0:
        return torch.ones(1, 1, h, w, device=device, dtype=dtype)

    eps_val = get_safe_epsilon(dtype)

    y, x = torch.meshgrid(
        torch.linspace(-1, 1, h, device=device, dtype=dtype),
        torch.linspace(-1, 1, w, device=device, dtype=dtype),
        indexing='ij'
    )

    # sqrt with protection
    dist = torch.sqrt(x * x + y * y + eps_val)

    # Normalise so the corners are ~1.0 (max distance is sqrt(2))
    dist = dist / math.sqrt(2.0)

    # Invert: center white (1), edges black (0)
    threshold = 1.0 - strength
    mask = 1.0 - torch.clamp((dist - threshold) / strength, 0, 1)

    if len(mask.shape) == 2:
        mask = mask.unsqueeze(0).unsqueeze(0)
    return mask

# =======================================================================
# 9. PARAMETER VALIDATION HELPERS
# =======================================================================

def validate_blend_params(params):
    """Extract and validate blend parameters from a dict.

    Unsupported falloff values are normalised here with a warning so
    callers never silently receive a mode the backend cannot honour.
    Supported: 'linear', 'smoothstep', 'cosine'

    Returns:
        dict with keys: strength, width, falloff, sharpness,
        fade_to_black, fade_strength
    """
    falloff = params.get('blend_falloff', 'smoothstep')
    _SUPPORTED_FALLOFFS = {'linear', 'smoothstep', 'cosine'}
    if falloff not in _SUPPORTED_FALLOFFS:
        print(f"[validate_blend_params] Warning: unsupported blend_falloff '{falloff}' "
              f"- falling back to 'smoothstep'. Supported: {sorted(_SUPPORTED_FALLOFFS)}")
        falloff = 'smoothstep'

    # A falsy blend_width (missing, None or 0) means "auto" -> None.
    width_raw = params.get('blend_width')

    return {
        'strength': float(params.get('blend_strength', 0.5)),
        'width': int(width_raw) if width_raw else None,
        'falloff': falloff,
        'sharpness': float(params.get('blend_sharpness', 1.0)),
        'fade_to_black': bool(params.get('blend_fade_to_black', False)),
        'fade_strength': float(params.get('blend_fade_strength', 0.1)),
    }


def validate_multires_params(params):
    """Extract and validate multi-resolution parameters from a dict.

    Unsupported strategy values are normalised here with a warning.
    Supported: 'linear', 'cosine', 'exponential', 'sigmoid'

    Returns:
        dict with keys: strategy, transition_start, transition_end, sharpness
    """
    strategy = params.get('multires_strategy', 'cosine')
    _SUPPORTED_STRATEGIES = {'linear', 'cosine', 'exponential', 'sigmoid'}
    if strategy not in _SUPPORTED_STRATEGIES:
        print(f"[validate_multires_params] Warning: unsupported multires_strategy '{strategy}' "
              f"- falling back to 'cosine'. Supported: {sorted(_SUPPORTED_STRATEGIES)}")
        strategy = 'cosine'

    return {
        'strategy': strategy,
        'transition_start': float(params.get('multires_start', 0.0)),
        'transition_end': float(params.get('multires_end', 0.3)),
        'sharpness': float(params.get('multires_sharpness', 1.0)),
    }

# =======================================================================
# vUltimate - End of File
# CRITICAL FIXES APPLIED:
#   Fix 1: Infinite recursion in get_safe_epsilon (v13 bug at line 51/65)
#   Fix 2: Correct v11+ mask semantics (advanced first, mask=1.0 on edges)
#   Fix 3: Cache key includes dtype (v11+ improvement)
#   Fix 4: Safe reflect with size validation (deduplicated in _pad_advanced)
#   Fix 5: v13 fade_to_black logic (not v7 logic)
#   Fix 6: v13 circular/vignette masks (not v1/exp)
#   Fix 7: create_fade_to_black_mask no longer divides by zero at strength=0
# =======================================================================