| import torch |
| import torch.nn.functional as F |
| import math |
| import numpy as np |
| from collections import OrderedDict |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
class SmartMaskCache:
    """Bounded LRU cache for blend masks so identical geometries are not regenerated.

    Backed by an OrderedDict: most-recently-used entries live at the end,
    and the oldest entry is evicted once `max_size` is exceeded.
    """

    def __init__(self, max_size=50):
        # insertion/access order doubles as recency order
        self.cache = OrderedDict()
        self.max_size = max_size

    def get(self, key):
        """Return the cached value for `key` (refreshing its recency), or None."""
        try:
            self.cache.move_to_end(key)
        except KeyError:
            return None
        return self.cache[key]

    def set(self, key, value):
        """Insert/replace `key`, then evict the least-recently-used entry if full."""
        if key in self.cache:
            self.cache.move_to_end(key)
        self.cache[key] = value
        while len(self.cache) > self.max_size:
            # last=False pops the oldest (least recently used) item
            self.cache.popitem(last=False)
|
|
| |
# Module-level singleton: blend masks are shared across all padding calls,
# keyed by geometry/curve/device/dtype (see compute_advanced_blend_padding).
_MASK_CACHE = SmartMaskCache()
|
|
|
|
| |
| |
| |
def get_safe_epsilon(tensor_or_dtype):
    """Return a numerically safe epsilon for the given tensor or dtype.

    Half-precision floats (float16 / bfloat16) underflow far earlier than
    float32, so they need a much larger epsilon to keep sqrt/division stable.

    Args:
        tensor_or_dtype: torch.Tensor or torch.dtype

    Returns:
        float: safe epsilon for the resolved dtype
    """
    dtype = (tensor_or_dtype.dtype
             if isinstance(tensor_or_dtype, torch.Tensor)
             else tensor_or_dtype)

    if dtype in (torch.float16, torch.bfloat16):
        return 1e-3
    # float64 and anything exotic get the tightest epsilon
    return 1e-6 if dtype == torch.float32 else 1e-12
|
|
|
|
| |
| |
| |
def blend_with_variance_fix(a, b, mask):
    """Blend two latent-noise tensors and renormalise the variance.

    A plain linear mix of two independent unit-variance noise fields has
    variance mask^2 + (1-mask)^2 < 1 in the transition zone; dividing by the
    sqrt of that restores unit variance. An adaptive epsilon keeps the sqrt
    finite in half precision.

    Mask semantics (v11+ convention):
        mask = 1.0 on EDGES (where the Advanced padding is wanted)
        mask = 0.0 in CENTER (content / Simple padding region)

    Args:
        a: Primary layer, active where mask = 1 (Advanced/Circular)
        b: Background layer, active where mask = 0 (Simple/Replicate)
        mask: Blend mask in [0, 1]

    Returns:
        Variance-corrected blended tensor.
    """
    inv = 1 - mask
    mixed = mask * a + inv * b

    eps = get_safe_epsilon(mask.dtype)
    norm = torch.sqrt(mask * mask + inv * inv + eps)
    return mixed / norm
|
|
|
|
| |
| |
| |
def compute_blend_fade_to_black(padded, pad_h, pad_w, fade_strength=0.1):
    """Darken the borders of an already-padded tensor (legacy "zoom" vignette).

    The linear fade spans the ENTIRE padding plus a `fade_strength` fraction
    of the content, ramping from 0 at the outer edge to 1 where the fade
    zone ends. (v13 semantics: fade covers padding + content margin; the
    older v7 variant faded only inside the padding.)

    Args:
        padded: Padded tensor [B, C, H, W]
        pad_h: Vertical padding size
        pad_w: Horizontal padding size
        fade_strength: Fraction of the content depth to fade into (0.0-1.0),
            typically 0.05-0.2

    Returns:
        Tensor of the same shape with darkened edges.
    """
    _, _, H, W = padded.shape

    # Content extents after removing padding (clamped for tiny inputs).
    content_h = max(H - 2 * pad_h, 0)
    content_w = max(W - 2 * pad_w, 0)

    # Total fade depth = full padding + a slice of the content.
    fade_h = pad_h + int(content_h * fade_strength)
    fade_w = pad_w + int(content_w * fade_strength)

    out = padded.clone()

    if fade_h > 0:
        ramp = torch.linspace(0, 1, steps=fade_h,
                              device=padded.device, dtype=padded.dtype).view(1, 1, -1, 1)
        n = min(fade_h, H)  # never index past the tensor
        out[:, :, :n, :] *= ramp[:, :, :n, :]
        out[:, :, -n:, :] *= ramp[:, :, :n, :].flip(2)

    if fade_w > 0:
        ramp = torch.linspace(0, 1, steps=fade_w,
                              device=padded.device, dtype=padded.dtype).view(1, 1, 1, -1)
        n = min(fade_w, W)
        out[:, :, :, :n] *= ramp[:, :, :, :n]
        out[:, :, :, -n:] *= ramp[:, :, :, :n].flip(3)

    return out
|
|
|
|
| |
| |
| |
def create_advanced_blend_mask(h, w, blend_width, device, dtype=torch.float32,
                               falloff_curve="smoothstep", edge_sharpness=1.0):
    """Build an edge-weighted blend mask of shape [1, 1, h, w].

    Mask semantics (v11+ convention):
        1.0 = at the very EDGE (where Advanced padding is wanted)
        0.0 = in the CENTER (content / Simple padding region)

    Args:
        h, w: Mask dimensions
        blend_width: Transition-zone width in pixels
        device: Torch device
        dtype: Data type
        falloff_curve: 'linear', 'smoothstep' or 'cosine' (unknown values
            fall back to 'smoothstep' with a warning)
        edge_sharpness: 1.0 = normal, >1 sharper, <1 softer

    Returns:
        Mask tensor of size [1, 1, h, w]
    """
    if blend_width <= 0:
        return torch.ones(1, 1, h, w, device=device, dtype=dtype)

    _KNOWN_FALLOFFS = {'linear', 'smoothstep', 'cosine'}
    if falloff_curve not in _KNOWN_FALLOFFS:
        print(f"[AdvancedBlend] Warning: unsupported falloff_curve '{falloff_curve}' "
              f"β falling back to 'smoothstep'. Supported: {sorted(_KNOWN_FALLOFFS)}")
        falloff_curve = 'smoothstep'

    # Clamp so opposite edges never overlap past the midline.
    ramp_w = min(blend_width, w // 2)
    ramp_h = min(blend_width, h // 2)

    mask = torch.zeros((1, 1, h, w), device=device, dtype=dtype)

    def _ramp(size):
        """0-to-1 gradient with the configured curve.

        BUG FIX (kept from original): for size == 1, linspace(0, 1, 1)
        yields [0] — a dead mask — so sizes <= 1 return ones instead,
        giving the single boundary pixel full weight.
        """
        if size <= 1:
            return torch.ones(max(size, 1), device=device, dtype=dtype)
        t = torch.linspace(0, 1, steps=size, device=device, dtype=dtype)
        if edge_sharpness != 1.0:
            t = torch.pow(t, edge_sharpness)
        if falloff_curve == 'cosine':
            return (1 - torch.cos(t * math.pi)) / 2
        if falloff_curve == 'smoothstep':
            return t * t * (3 - 2 * t)
        return t  # 'linear'

    if ramp_w > 0:
        grad = _ramp(ramp_w)
        left = grad.flip(0).view(1, 1, 1, -1)    # 1 at outer edge, 0 inward
        right = grad.view(1, 1, 1, -1)
        mask[:, :, :, :ramp_w] = torch.maximum(mask[:, :, :, :ramp_w], left)
        mask[:, :, :, -ramp_w:] = torch.maximum(mask[:, :, :, -ramp_w:], right)

    if ramp_h > 0:
        grad = _ramp(ramp_h)
        top = grad.flip(0).view(1, 1, -1, 1)
        bottom = grad.view(1, 1, -1, 1)
        mask[:, :, :ramp_h, :] = torch.maximum(mask[:, :, :ramp_h, :], top)
        mask[:, :, -ramp_h:, :] = torch.maximum(mask[:, :, -ramp_h:, :], bottom)

    return mask
|
|
|
|
| |
| |
| |
def _pad_simple_mode(input_tensor, pad_h, pad_w, mode_simple):
    """Apply the 'simple' padding; 'constant' pads with zeros."""
    if mode_simple == 'constant':
        return F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode='constant', value=0)
    return F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode=mode_simple)


def _pad_advanced_mode(input_tensor, pad_h, pad_w, mode_advanced):
    """Apply the 'advanced' padding.

    'reflect' requires pad < dim on both axes; when that fails we fall back
    to 'replicate'. A non-string `mode_advanced` is treated as an
    already-padded tensor and returned unchanged.
    """
    if not isinstance(mode_advanced, str):
        return mode_advanced
    if mode_advanced == 'reflect':
        _, _, h, w = input_tensor.shape
        if pad_w < w and pad_h < h:
            return F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode='reflect')
        return F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode='replicate')
    return F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode=mode_advanced)


def compute_advanced_blend_padding(input_tensor, pad_h, pad_w,
                                   mode_simple='replicate',
                                   mode_advanced='circular',
                                   blend_strength=0.5,
                                   blend_width=None,
                                   falloff_curve='smoothstep',
                                   edge_sharpness=1.0,
                                   fade_to_black=False,
                                   fade_strength=0.1):
    """IMPROVED PADDING MODE.

    Two operation modes:

    1. FADE TO BLACK (fade_to_black=True) - for the Zoom effect:
       - Applies one padding (mode_advanced)
       - DARKENS edges via compute_blend_fade_to_black, creating a zoom-out look

    2. BLEND TWO PADDINGS (fade_to_black=False) - for quality edges:
       - Creates two different paddings (simple and advanced)
       - Blends them via a cached edge mask
       - Applies a variance fix for color correction
       - Does NOT create a zoom effect

    Args:
        input_tensor: Original tensor WITHOUT padding [B, C, H, W]
        pad_h, pad_w: Padding sizes
        mode_simple: Mode for the "simple" padding ('replicate', 'constant')
        mode_advanced: Mode for the "advanced" padding ('circular', 'reflect'),
            or an already-padded tensor
        blend_strength: Blend strength (0.0-1.0)
        blend_width: Transition width in pixels (None = auto from padding)
        falloff_curve: Gradient curve type for the blend mask
        edge_sharpness: Edge sharpness for the blend mask
        fade_to_black: If True, uses the legacy darkening mode
        fade_strength: Darkening strength for fade_to_black mode

    Returns:
        Padded tensor [B, C, H + 2*pad_h, W + 2*pad_w]
    """
    # --- Mode 1: legacy fade-to-black (zoom) -------------------------------
    if fade_to_black:
        padded = _pad_advanced_mode(input_tensor, pad_h, pad_w, mode_advanced)
        return compute_blend_fade_to_black(padded, pad_h, pad_w, fade_strength)

    # --- Mode 2: blend of two paddings -------------------------------------
    # Fast paths when the blend is pinned to either extreme.
    if blend_strength <= 0.001:
        return _pad_simple_mode(input_tensor, pad_h, pad_w, mode_simple)
    if blend_strength >= 0.999:
        return _pad_advanced_mode(input_tensor, pad_h, pad_w, mode_advanced)

    simple = _pad_simple_mode(input_tensor, pad_h, pad_w, mode_simple)
    # BUG FIX: the original hard-coded mode='circular' here for any
    # non-'reflect' string, silently ignoring the requested mode_advanced.
    # Now the blend path honours mode_advanced exactly like the fast paths.
    advanced = _pad_advanced_mode(input_tensor, pad_h, pad_w, mode_advanced)

    if blend_width is None:
        blend_width = max(pad_h, pad_w)

    _, _, h, w = input_tensor.shape
    device = input_tensor.device
    dtype = input_tensor.dtype

    # Padded dims are fully determined by (h, w, pad_h, pad_w), so this key
    # uniquely identifies the mask geometry.
    cache_key = (h, w, pad_h, pad_w, blend_width, falloff_curve, edge_sharpness,
                 str(device), str(dtype))
    mask = _MASK_CACHE.get(cache_key)
    if mask is None:
        H_pad, W_pad = simple.shape[2:]
        mask = create_advanced_blend_mask(H_pad, W_pad, blend_width, device,
                                          dtype, falloff_curve, edge_sharpness)
        _MASK_CACHE.set(cache_key, mask)

    # Scale edge weights by overall blend strength.
    final_mask = mask * blend_strength

    return blend_with_variance_fix(advanced, simple, final_mask)
|
|
|
|
| |
| |
| |
class BlendStrategy:
    """Names of the interpolation curves used for multi-resolution transitions."""
    LINEAR = "linear"            # straight line: factor == progress
    COSINE = "cosine"            # half-cosine ease-in / ease-out
    EXPONENTIAL = "exponential"  # quadratic ease-in
    SIGMOID = "sigmoid"          # logistic S-curve centred at 0.5
|
|
class MultiResStrategy:
    """Maps denoising progress to a blend factor via a chosen temporal curve."""

    def __init__(self, strategy_type=BlendStrategy.COSINE):
        self.strategy_type = strategy_type

    def get_factor(self, progress, sharpness=1.0):
        """Return the blend factor for `progress` in [0, 1].

        `sharpness` != 1.0 pre-warps progress with a power curve before the
        strategy curve is applied. Unknown strategies fall through to linear.
        """
        t = min(1.0, max(0.0, progress))
        if sharpness != 1.0:
            t = math.pow(t, sharpness)

        kind = self.strategy_type
        if kind == BlendStrategy.COSINE:
            return (1.0 - math.cos(t * math.pi)) / 2.0
        if kind == BlendStrategy.EXPONENTIAL:
            return math.pow(t, 2)
        if kind == BlendStrategy.SIGMOID:
            # clamp the asymptotes so the ends are exactly 0 and 1
            if t <= 0:
                return 0.0
            if t >= 1:
                return 1.0
            return 1.0 / (1.0 + math.exp(-12.0 * (t - 0.5)))
        return t  # LINEAR and any unrecognised value
|
|
def apply_multires_blend(tensor_simple, tensor_advanced, current_step,
                         start_step, end_step,
                         strategy="cosine",
                         transition_start=0.0,
                         transition_end=0.3,
                         sharpness=1.0,
                         enabled=False):
    """Progressively blend from simple to advanced over denoising steps.

    Args:
        tensor_simple: Low-detail padding result
        tensor_advanced: High-detail padding result
        current_step: Current denoising step
        start_step, end_step: Denoising range
        strategy: Interpolation curve (name string or BlendStrategy constant)
        transition_start, transition_end: Transition window as fractions (0.0-1.0)
        sharpness: Curve adjustment
        enabled: Master switch; when False, returns tensor_advanced unchanged

    Returns:
        Blended tensor (or one of the inputs at the extremes).
    """
    if not enabled:
        return tensor_advanced

    # Normalise the strategy argument; unknown values degrade to cosine.
    _STRATEGY_ALIASES = {
        'linear': BlendStrategy.LINEAR,
        'cosine': BlendStrategy.COSINE,
        'exponential': BlendStrategy.EXPONENTIAL,
        'sigmoid': BlendStrategy.SIGMOID,
    }
    _KNOWN_STRATEGIES = {BlendStrategy.LINEAR, BlendStrategy.COSINE,
                         BlendStrategy.EXPONENTIAL, BlendStrategy.SIGMOID}
    if isinstance(strategy, str):
        resolved = _STRATEGY_ALIASES.get(strategy.lower())
        if resolved is None:
            print(f"[MultiRes] Warning: unsupported strategy '{strategy}' "
                  f"β falling back to 'cosine'. "
                  f"Supported: {sorted(_STRATEGY_ALIASES.keys())}")
            strategy = BlendStrategy.COSINE
        else:
            strategy = resolved
    elif strategy not in _KNOWN_STRATEGIES:
        print(f"[MultiRes] Warning: unknown strategy {strategy!r} β falling back to cosine")
        strategy = BlendStrategy.COSINE

    span = end_step - start_step
    if span <= 0:
        return tensor_advanced

    # Overall progress through the denoising range, clamped to [0, 1].
    frac = min(1.0, max(0.0, (current_step - start_step) / span))

    # Map overall progress onto the local transition window.
    if frac < transition_start:
        local_progress = 0.0
    elif frac > transition_end:
        local_progress = 1.0
    else:
        window = transition_end - transition_start
        local_progress = 1.0 if window <= 0 else (frac - transition_start) / window

    alpha = MultiResStrategy(strategy).get_factor(local_progress, sharpness)

    # Skip the lerp entirely at the extremes.
    if alpha <= 0.001:
        return tensor_simple
    if alpha >= 0.999:
        return tensor_advanced

    return tensor_simple * (1.0 - alpha) + tensor_advanced * alpha
|
|
|
|
| |
| |
| |
|
|
def create_circular_mask(h, w, center_x=0.5, center_y=0.5, radius=0.5,
                         device='cpu', dtype=torch.float32):
    """Create a circular mask (white disc on black background), shape [1, 1, h, w].

    Distances are computed on a normalised [-1, 1] grid; an adaptive epsilon
    inside the sqrt keeps half-precision dtypes finite. The disc edge is
    softened over a fixed 0.2-wide band centred near `radius`.

    Args:
        h, w: Mask dimensions
        center_x, center_y: Disc centre in [0, 1] coordinates
        radius: Disc radius in normalised units
        device, dtype: Output placement
    """
    eps = get_safe_epsilon(dtype)

    grid_y, grid_x = torch.meshgrid(
        torch.linspace(-1, 1, h, device=device, dtype=dtype),
        torch.linspace(-1, 1, w, device=device, dtype=dtype),
        indexing='ij'
    )

    # Shift the grid so the disc centre lands at (center_x, center_y).
    grid_x = grid_x - (center_x - 0.5) * 2
    grid_y = grid_y - (center_y - 0.5) * 2

    dist = torch.sqrt(grid_x * grid_x + grid_y * grid_y + eps)

    # Soft edge: 1 inside radius-0.1, 0 outside radius+0.1.
    mask = 1.0 - torch.clamp((dist - (radius - 0.1)) / 0.2, 0, 1)

    if mask.dim() == 2:
        mask = mask.unsqueeze(0).unsqueeze(0)

    return mask
|
|
def create_fade_to_black_mask(h, w, strength=0.1, device='cpu', dtype=torch.float32):
    """Create a radial vignette mask (darkening towards the corners), shape [1, 1, h, w].

    Distances on a normalised [-1, 1] grid are divided by sqrt(2) so the
    corner distance maps to ~1.0; the outer `strength` fraction of that
    radius fades linearly from 1 down to 0. An adaptive epsilon inside the
    sqrt keeps half-precision dtypes finite.
    """
    eps = get_safe_epsilon(dtype)

    grid_y, grid_x = torch.meshgrid(
        torch.linspace(-1, 1, h, device=device, dtype=dtype),
        torch.linspace(-1, 1, w, device=device, dtype=dtype),
        indexing='ij'
    )

    dist = torch.sqrt(grid_x * grid_x + grid_y * grid_y + eps)

    # Normalise by the corner distance (sqrt(2) ≈ 1.4142).
    dist = dist / 1.4142

    cutoff = 1.0 - strength
    mask = 1.0 - torch.clamp((dist - cutoff) / strength, 0, 1)

    if mask.dim() == 2:
        mask = mask.unsqueeze(0).unsqueeze(0)

    return mask
|
|
|
|
| |
| |
| |
|
|
def validate_blend_params(params):
    """Extract and validate blend parameters from a dict.

    Unsupported falloff values are normalised here with a warning so callers
    never silently receive a mode the backend cannot honour.
    Supported: 'linear', 'smoothstep', 'cosine'
    """
    _SUPPORTED_FALLOFFS = {'linear', 'smoothstep', 'cosine'}
    falloff = params.get('blend_falloff', 'smoothstep')
    if falloff not in _SUPPORTED_FALLOFFS:
        print(f"[validate_blend_params] Warning: unsupported blend_falloff '{falloff}' "
              f"β falling back to 'smoothstep'. Supported: {sorted(_SUPPORTED_FALLOFFS)}")
        falloff = 'smoothstep'

    # A falsy blend_width (missing, None, 0) means "auto" downstream.
    raw_width = params.get('blend_width')

    return {
        'strength': float(params.get('blend_strength', 0.5)),
        'width': int(raw_width) if raw_width else None,
        'falloff': falloff,
        'sharpness': float(params.get('blend_sharpness', 1.0)),
        'fade_to_black': bool(params.get('blend_fade_to_black', False)),
        'fade_strength': float(params.get('blend_fade_strength', 0.1)),
    }
|
|
def validate_multires_params(params):
    """Extract and validate multi-resolution parameters from a dict.

    Unsupported strategy values are normalised here with a warning.
    Supported: 'linear', 'cosine', 'exponential', 'sigmoid'
    """
    _SUPPORTED_STRATEGIES = {'linear', 'cosine', 'exponential', 'sigmoid'}
    strategy = params.get('multires_strategy', 'cosine')
    if strategy not in _SUPPORTED_STRATEGIES:
        print(f"[validate_multires_params] Warning: unsupported multires_strategy '{strategy}' "
              f"β falling back to 'cosine'. Supported: {sorted(_SUPPORTED_STRATEGIES)}")
        strategy = 'cosine'

    return {
        'strategy': strategy,
        'transition_start': float(params.get('multires_start', 0.0)),
        'transition_end': float(params.get('multires_end', 0.3)),
        'sharpness': float(params.get('multires_sharpness', 1.0)),
    }
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|