# Source: sdas/asd/libs/improved_tiling_functions.py
# Uploaded by dikdimon ("Upload improved_tiling_functions.py"), commit 49b7391 (verified).
import torch
import torch.nn.functional as F
import math
import numpy as np
from collections import OrderedDict
# =======================================================================
# vUltimate - REAL Deep Code Audit
# Based on v13 (THE_last_version) with CRITICAL FIXES
# =======================================================================
# =======================================================================
# 1. SMART CACHING (Speed optimization ~15-20%)
# =======================================================================
class SmartMaskCache:
    """Least-recently-used (LRU) cache for generated blend masks.

    Entries live in an ``OrderedDict`` ordered from least to most recently
    used. Reading or re-writing a key marks it as most recently used; once
    more than ``max_size`` entries are stored, the oldest one is evicted.
    """

    def __init__(self, max_size=50):
        # Capacity before the least recently used entry is dropped.
        self.max_size = max_size
        # Oldest entry first, newest entry last.
        self.cache = OrderedDict()

    def get(self, key):
        """Return the value stored for ``key`` (refreshing its recency), or None."""
        if key not in self.cache:
            return None
        self.cache.move_to_end(key)
        return self.cache[key]

    def set(self, key, value):
        """Store ``value`` under ``key``; evict the oldest entry when over capacity."""
        store = self.cache
        if key in store:
            store.move_to_end(key)
        store[key] = value
        over_capacity = len(store) > self.max_size
        if over_capacity:
            store.popitem(last=False)


# Module-wide mask cache shared by the padding helpers below.
_MASK_CACHE = SmartMaskCache()
# =======================================================================
# 2. SAFE EPSILON (🔴 CRITICAL FIX: infinite recursion bug in v13!)
# =======================================================================
def get_safe_epsilon(tensor_or_dtype):
    """
    Return a numerically safe epsilon for a tensor or a dtype.

    Half-precision types get a much larger epsilon, because values such as
    1e-6 fall below float16's smallest normal number (~6.1e-5).

    Note: the v13 version recursed into itself instead of returning a value;
    this implementation returns constants directly.

    Args:
        tensor_or_dtype: torch.Tensor (its ``.dtype`` is used) or torch.dtype

    Returns:
        float: 1e-3 for float16/bfloat16, 1e-6 for float32, 1e-12 for
        anything else (float64, integer dtypes, ...)
    """
    if isinstance(tensor_or_dtype, torch.Tensor):
        tensor_or_dtype = tensor_or_dtype.dtype

    # (dtypes, epsilon) pairs checked in order; the first match wins.
    table = (
        ((torch.float16, torch.bfloat16), 1e-3),
        ((torch.float32,), 1e-6),
    )
    for dtypes, eps in table:
        if tensor_or_dtype in dtypes:
            return eps
    return 1e-12
# =======================================================================
# 3. LATENT COLOR FIX (Variance-preserving blend)
# =======================================================================
def blend_with_variance_fix(a, b, mask):
"""
Mathematically correct blending of latent noise.
βœ… FLOAT16 FIX: Safe sqrt with adaptive epsilon
Args:
a: Primary layer (active where mask=1) -> Advanced/Circular
b: Background layer (active where mask=0) -> Simple/Replicate
mask: Blend mask [0,1]
πŸ”΄ IMPORTANT: From v11+ the mask semantics are:
mask=1.0 on EDGES (where Advanced padding is needed)
mask=0.0 in CENTER (where content or Simple padding is)
"""
# 1. Linear blend
blended = a * mask + b * (1 - mask)
# 2. Variance correction with adaptive epsilon
eps_val = get_safe_epsilon(mask.dtype)
variance_fix = torch.sqrt(mask**2 + (1 - mask)**2 + eps_val)
return blended / variance_fix
# =======================================================================
# 4. LEGACY FADE TO BLACK (For ZOOM effect) ✅
# =======================================================================
def compute_blend_fade_to_black(padded, pad_h, pad_w, fade_strength=0.1):
    """
    Darken the borders of an already padded tensor (legacy ZOOM effect).

    This is the v11-v13 logic (not v7): the fade covers the whole padding
    ring plus a slice of the content, giving a vignette that goes from 0 at
    the outer edge to 1 further inside.

    For each spatial axis the fade zone is
    ``pad + int(content * fade_strength)`` pixels, where
    ``content = max(size - 2 * pad, 0)``. A linear ramp 0 -> 1 multiplies
    the leading rows/columns of that length and the mirrored ramp multiplies
    the trailing ones. Vertical ramps are applied first, then horizontal.
    If the leading and trailing zones overlap (roughly fade_strength > 0.5),
    the overlapping pixels are multiplied by both ramps.

    Args:
        padded: Already padded tensor [B, C, H, W] (not modified)
        pad_h: Vertical padding size
        pad_w: Horizontal padding size
        fade_strength: Fade depth into content as a fraction of the content
            size (typically 0.05-0.2)

    Returns:
        New tensor with darkened edges, same shape/dtype/device as ``padded``
    """
    _, _, height, width = padded.shape
    out = padded.clone()

    # axis 2 = rows (vertical edges), axis 3 = columns (horizontal edges)
    for axis, size, pad in ((2, height, pad_h), (3, width, pad_w)):
        content = max(size - 2 * pad, 0)
        zone = pad + int(content * fade_strength)
        if zone <= 0:
            continue

        view_shape = [1, 1, 1, 1]
        view_shape[axis] = -1
        ramp = torch.linspace(0, 1, steps=zone,
                              device=padded.device, dtype=padded.dtype).view(*view_shape)

        # Never touch more than the tensor actually has along this axis.
        span = min(zone, size)
        head = ramp.narrow(axis, 0, span)
        out.narrow(axis, 0, span).mul_(head)                    # leading edge
        out.narrow(axis, size - span, span).mul_(head.flip(axis))  # trailing edge

    return out
# =======================================================================
# 5. ADVANCED BLEND MASK (For modern tiling mode)
# =======================================================================
def create_advanced_blend_mask(h, w, blend_width, device, dtype=torch.float32,
                               falloff_curve="smoothstep", edge_sharpness=1.0):
    """
    Create an edge blend mask for the advanced/simple padding blend.

    MASK SEMANTICS (v11+ convention):
        1.0 = on the very EDGE (where Advanced Padding is needed)
        0.0 = in CENTER (where content or Simple Padding is)

    Each side gets a ramp of ``min(blend_width, side // 2)`` pixels that is
    1.0 at the outermost pixel and falls to 0.0 at the innermost ramp pixel
    (a 1-pixel ramp is simply 1.0). Where horizontal and vertical ramps meet
    (corners), the maximum of the two is kept.

    Args:
        h, w: Mask dimensions
        blend_width: Transition zone width in pixels. Values <= 0 mean "no
            transition zone" and produce an all-zero mask (nothing receives
            advanced padding). Non-integer widths are truncated.
        device: Torch device
        dtype: Data type
        falloff_curve: Curve type ('linear', 'smoothstep', 'cosine');
            anything else prints a warning and falls back to 'smoothstep'
        edge_sharpness: Exponent applied to the ramp parameter before the
            curve (1.0 = unchanged)

    Returns:
        Mask of size [1, 1, h, w]
    """
    if blend_width <= 0:
        # No transition zone -> no pixel gets advanced padding. An all-ones
        # mask here would, under the v11+ semantics, blend the entire tensor
        # (content included) and break continuity with blend_width == 1.
        return torch.zeros(1, 1, h, w, device=device, dtype=dtype)

    # BUG FIX 6a: normalise falloff_curve to a known value; warn loudly if
    # the UI has sent something the backend doesn't actually implement.
    _KNOWN_FALLOFFS = {'linear', 'smoothstep', 'cosine'}
    if falloff_curve not in _KNOWN_FALLOFFS:
        print(f"[AdvancedBlend] Warning: unsupported falloff_curve '{falloff_curve}' "
              f"β€” falling back to 'smoothstep'. Supported: {sorted(_KNOWN_FALLOFFS)}")
        falloff_curve = 'smoothstep'

    # Ramps are capped at half the side so opposite ramps never overlap;
    # int() keeps slicing/linspace valid for float widths.
    blend_w = int(min(blend_width, w // 2))
    blend_h = int(min(blend_width, h // 2))
    mask = torch.zeros((1, 1, h, w), device=device, dtype=dtype)

    def get_ramp(size):
        """Return a 0 -> 1 gradient of ``size`` samples shaped by the falloff.

        BUG FIX: for size == 1, linspace(0, 1, 1) yields [0], i.e. a zero
        mask. For size <= 1 we return ones instead, so the border pixel gets
        full weight.
        """
        if size <= 1:
            return torch.ones(max(size, 1), device=device, dtype=dtype)
        t = torch.linspace(0, 1, steps=size, device=device, dtype=dtype)
        if edge_sharpness != 1.0:
            t = torch.pow(t, edge_sharpness)
        if falloff_curve == 'smoothstep':
            return t * t * (3 - 2 * t)
        if falloff_curve == 'cosine':
            return (1 - torch.cos(t * math.pi)) / 2
        return t  # 'linear'

    # Fill edges
    if blend_w > 0:
        ramp = get_ramp(blend_w)
        # Left edge: 1.0 at column 0, decreasing inwards
        mask[:, :, :, :blend_w] = torch.maximum(mask[:, :, :, :blend_w],
                                                ramp.flip(0).view(1, 1, 1, -1))
        # Right edge: 1.0 at the last column
        mask[:, :, :, -blend_w:] = torch.maximum(mask[:, :, :, -blend_w:],
                                                 ramp.view(1, 1, 1, -1))
    if blend_h > 0:
        ramp = get_ramp(blend_h)
        # Top edge
        mask[:, :, :blend_h, :] = torch.maximum(mask[:, :, :blend_h, :],
                                                ramp.flip(0).view(1, 1, -1, 1))
        # Bottom edge
        mask[:, :, -blend_h:, :] = torch.maximum(mask[:, :, -blend_h:, :],
                                                 ramp.view(1, 1, -1, 1))
    return mask
# =======================================================================
# 6. IMPROVED BLEND PADDING (Main tiling function)
# =======================================================================
def _pad_with_mode(input_tensor, pad_h, pad_w, mode):
    """Pad a [B, C, H, W] tensor symmetrically on H and W with ``F.pad``.

    'reflect' is downgraded to 'replicate' when either pad is not smaller than
    the matching dimension (F.pad's reflect mode requires pad < size), and
    'constant' pads with zeros.
    """
    pads = (pad_w, pad_w, pad_h, pad_h)
    if mode == 'reflect':
        _, _, h, w = input_tensor.shape
        if not (pad_w < w and pad_h < h):
            mode = 'replicate'
    if mode == 'constant':
        return F.pad(input_tensor, pads, mode='constant', value=0)
    return F.pad(input_tensor, pads, mode=mode)


def _resolve_advanced_layer(input_tensor, pad_h, pad_w, mode_advanced):
    """Build the advanced padding layer from a mode string, or return the
    caller-supplied pre-computed tensor unchanged."""
    if isinstance(mode_advanced, str):
        return _pad_with_mode(input_tensor, pad_h, pad_w, mode_advanced)
    return mode_advanced


def compute_advanced_blend_padding(input_tensor, pad_h, pad_w,
                                   mode_simple='replicate',
                                   mode_advanced='circular',
                                   blend_strength=0.5,
                                   blend_width=None,
                                   falloff_curve='smoothstep',
                                   edge_sharpness=1.0,
                                   fade_to_black=False,
                                   fade_strength=0.1):
    """
    IMPROVED PADDING MODE

    Two operation modes:

    1. FADE TO BLACK (fade_to_black=True) - for the Zoom effect:
       - applies one padding (mode_advanced)
       - darkens the edges via compute_blend_fade_to_black (legacy logic)

    2. BLEND TWO PADDINGS (fade_to_black=False) - for quality edges:
       - builds a "simple" and an "advanced" padded layer
       - blends them with an edge mask (1 = advanced, 0 = simple)
       - applies the variance fix (blend_with_variance_fix)
       - does NOT create a zoom effect

    In mode 2 the mask is forced to 0 over the original (un-padded) content
    area. Both string-built layers contain the unmodified input there, so
    blending them cannot change anything except through the variance fix.
    That fix would brighten the content by up to 1/sqrt(0.5) wherever the
    ramp reached into it, which happened whenever pad_h != pad_w or
    blend_width > pad. NOTE(review): if a pre-computed advanced tensor has
    a different interior, that interior is now ignored in favour of the
    simple layer - confirm this matches caller expectations.

    Args:
        input_tensor: Original tensor WITHOUT padding [B, C, H, W]
        pad_h, pad_w: Padding sizes
        mode_simple: F.pad mode for the "simple" layer ('replicate',
            'constant', ...). 'reflect' falls back to 'replicate' when the
            pad is too large.
        mode_advanced: F.pad mode for the "advanced" layer ('circular',
            'reflect', ...), or a pre-computed padded tensor used as-is
        blend_strength: Blend strength (0.0-1.0); <= 0.001 returns the
            simple layer, >= 0.999 the advanced layer
        blend_width: Transition width in pixels (None = max(pad_h, pad_w))
        falloff_curve: Gradient curve type
        edge_sharpness: Edge sharpness exponent
        fade_to_black: If True, uses the legacy darkening mode
        fade_strength: Darkening strength for fade_to_black mode

    Returns:
        Padded tensor [B, C, H+2*pad_h, W+2*pad_w]
    """
    # MODE 1: FADE TO BLACK (for Zoom)
    if fade_to_black:
        padded = _resolve_advanced_layer(input_tensor, pad_h, pad_w, mode_advanced)
        return compute_blend_fade_to_black(padded, pad_h, pad_w, fade_strength)

    # MODE 2: BLEND TWO PADDINGS (Tiling)
    if blend_strength <= 0.001:
        return _pad_with_mode(input_tensor, pad_h, pad_w, mode_simple)
    if blend_strength >= 0.999:
        return _resolve_advanced_layer(input_tensor, pad_h, pad_w, mode_advanced)

    # 1. Prepare layers. The advanced mode is honoured here too; it used to
    #    be hard-coded to 'circular' for every non-reflect string.
    simple = _pad_with_mode(input_tensor, pad_h, pad_w, mode_simple)
    advanced = _resolve_advanced_layer(input_tensor, pad_h, pad_w, mode_advanced)

    # 2. Get mask from cache
    if blend_width is None:
        blend_width = max(pad_h, pad_w)

    _, _, h, w = input_tensor.shape
    device = input_tensor.device
    dtype = input_tensor.dtype

    # Cache key includes dtype/device so masks are never reused across them.
    cache_key = (h, w, pad_h, pad_w, blend_width, falloff_curve, edge_sharpness,
                 str(device), str(dtype))
    mask = _MASK_CACHE.get(cache_key)
    if mask is None:
        H_pad, W_pad = simple.shape[2:]
        mask = create_advanced_blend_mask(H_pad, W_pad, blend_width, device,
                                          dtype, falloff_curve, edge_sharpness)
        # Never blend over the original content (see docstring).
        mask[:, :, pad_h:pad_h + h, pad_w:pad_w + w] = 0
        _MASK_CACHE.set(cache_key, mask)

    # 3. Blend with variance fix; the product is a new tensor, so the cached
    #    mask is not modified.
    final_mask = mask * blend_strength
    # v11+ semantics: advanced where mask=1 (edges), simple where mask=0.
    return blend_with_variance_fix(advanced, simple, final_mask)
# =======================================================================
# 7. MULTI-RESOLUTION (Temporal strategy)
# =======================================================================
class BlendStrategy:
    """Names of the interpolation curves understood by MultiResStrategy."""
    LINEAR = "linear"
    COSINE = "cosine"
    EXPONENTIAL = "exponential"  # implemented as a quadratic ease-in (t**2)
    SIGMOID = "sigmoid"


class MultiResStrategy:
    """Handles temporal blending curves for progressive detail addition.

    ``strategy_type`` is one of the BlendStrategy values; unknown values
    behave like LINEAR.
    """

    def __init__(self, strategy_type=BlendStrategy.COSINE):
        self.strategy_type = strategy_type

    @staticmethod
    def _logistic(t):
        """Logistic curve centred at t=0.5 with steepness 12."""
        return 1.0 / (1.0 + math.exp(-12.0 * (t - 0.5)))

    def get_factor(self, progress, sharpness=1.0):
        """
        Map transition progress to a blend factor.

        Args:
            progress: Transition progress; clamped to [0.0, 1.0]
            sharpness: Exponent applied to the clamped progress before the
                curve (1.0 = unchanged)

        Returns:
            float: blend factor. Every built-in curve maps 0 -> 0 and 1 -> 1.
        """
        t = max(0.0, min(1.0, progress))
        if sharpness != 1.0:
            t = math.pow(t, sharpness)

        if self.strategy_type == BlendStrategy.LINEAR:
            return t
        if self.strategy_type == BlendStrategy.COSINE:
            return (1.0 - math.cos(t * math.pi)) / 2.0
        if self.strategy_type == BlendStrategy.EXPONENTIAL:
            return math.pow(t, 2)
        if self.strategy_type == BlendStrategy.SIGMOID:
            if t <= 0:
                return 0.0
            if t >= 1:
                return 1.0
            # The raw logistic is only ~0.0025 at t=0 and ~0.9975 at t=1, so
            # forcing the endpoints made the curve jump at both ends. Rescale
            # it to run continuously from exactly 0 to exactly 1.
            lo = self._logistic(0.0)
            hi = self._logistic(1.0)
            return (self._logistic(t) - lo) / (hi - lo)
        return t
def _normalize_multires_strategy(strategy):
    """Resolve ``strategy`` to a BlendStrategy value.

    Strings are matched case-insensitively against the known names; other
    values must already be a BlendStrategy value. Anything unsupported
    prints a warning and resolves to BlendStrategy.COSINE.
    """
    aliases = {
        'linear': BlendStrategy.LINEAR,
        'cosine': BlendStrategy.COSINE,
        'exponential': BlendStrategy.EXPONENTIAL,
        'sigmoid': BlendStrategy.SIGMOID,
    }
    if not isinstance(strategy, str):
        known = {BlendStrategy.LINEAR, BlendStrategy.COSINE,
                 BlendStrategy.EXPONENTIAL, BlendStrategy.SIGMOID}
        if strategy in known:
            return strategy
        print(f"[MultiRes] Warning: unknown strategy {strategy!r} β€” falling back to cosine")
        return BlendStrategy.COSINE

    lowered = strategy.lower()
    if lowered in aliases:
        return aliases[lowered]
    print(f"[MultiRes] Warning: unsupported strategy '{strategy}' "
          f"β€” falling back to 'cosine'. "
          f"Supported: {sorted(aliases.keys())}")
    return BlendStrategy.COSINE


def apply_multires_blend(tensor_simple, tensor_advanced, current_step,
                         start_step, end_step,
                         strategy="cosine",
                         transition_start=0.0,
                         transition_end=0.3,
                         sharpness=1.0,
                         enabled=False):
    """
    Progressively cross-fade from the simple to the advanced result over the
    denoising steps.

    The step position inside [start_step, end_step] is turned into a
    fraction in [0, 1]; the window [transition_start, transition_end] of
    that fraction is remapped to 0..1 and shaped by the chosen curve
    (MultiResStrategy) to obtain ``alpha``.

    Args:
        tensor_simple: Low-detail padding result
        tensor_advanced: High-detail padding result
        current_step: Current denoising step
        start_step, end_step: Denoising range
        strategy: Curve name (case-insensitive) or BlendStrategy value;
            unsupported values warn and fall back to cosine
        transition_start, transition_end: Transition window (0.0-1.0)
        sharpness: Curve adjustment exponent
        enabled: Master switch; when False the advanced tensor is returned

    Returns:
        ``tensor_simple`` when alpha <= 0.001, ``tensor_advanced`` when
        alpha >= 0.999 (or when disabled / the range is empty), otherwise a
        linear blend of the two.
    """
    if not enabled:
        return tensor_advanced

    strategy = _normalize_multires_strategy(strategy)

    span = end_step - start_step
    if span <= 0:
        return tensor_advanced

    frac = max(0.0, min(1.0, (current_step - start_step) / span))

    if frac < transition_start:
        progress = 0.0
    elif frac > transition_end:
        progress = 1.0
    else:
        window = transition_end - transition_start
        progress = 1.0 if window <= 0 else (frac - transition_start) / window

    alpha = MultiResStrategy(strategy).get_factor(progress, sharpness)

    # Snap near-endpoint factors to the pure inputs (avoids a needless blend).
    if alpha <= 0.001:
        return tensor_simple
    if alpha >= 0.999:
        return tensor_advanced
    # Temporal (not spatial) linear interpolation.
    return tensor_simple * (1.0 - alpha) + tensor_advanced * alpha
# =======================================================================
# 8. HELPER FUNCTIONS (From v13 for compatibility)
# =======================================================================
def create_circular_mask(h, w, center_x=0.5, center_y=0.5, radius=0.5,
                         device='cpu', dtype=torch.float32):
    """
    Build a soft-edged disc mask: 1.0 inside the circle, 0.0 outside.

    Both axes span [-1, 1] independently, so for h != w the "circle" is an
    ellipse in pixel space. The edge ramps linearly from 1.0 at distance
    ``radius - 0.1`` to 0.0 at ``radius + 0.1``.

    A dtype-dependent epsilon (get_safe_epsilon) is added under the sqrt so
    the distance stays strictly positive at the centre (float16-safe).

    Args:
        h, w: Mask height and width
        center_x, center_y: Circle centre in [0, 1] image coordinates
            (0.5 = middle, 0 = left/top, 1 = right/bottom)
        radius: Circle radius in the [-1, 1] coordinate frame
        device, dtype: Placement and precision of the result

    Returns:
        Tensor of shape [1, 1, h, w]

    NOTE: this is the v13 radial-distance mask; v1/exp had no equivalent.
    """
    eps_val = get_safe_epsilon(dtype)

    # Separable, centre-shifted axes combined by broadcasting (h, 1) x (1, w).
    cols = torch.linspace(-1, 1, w, device=device, dtype=dtype) - (center_x - 0.5) * 2
    rows = torch.linspace(-1, 1, h, device=device, dtype=dtype) - (center_y - 0.5) * 2
    xx = cols.view(1, -1)
    yy = rows.view(-1, 1)

    dist = torch.sqrt(xx * xx + yy * yy + eps_val)
    outside = torch.clamp((dist - (radius - 0.1)) / 0.2, 0, 1)
    return (1.0 - outside)[None, None]
def create_fade_to_black_mask(h, w, strength=0.1, device='cpu', dtype=torch.float32):
    """
    Create a radial vignette: 1.0 in the middle, fading to 0.0 towards the
    corners.

    The distance from the centre (both axes spanning [-1, 1]) is divided by
    1.4142 (~sqrt(2)), so the corners sit at ~1.0. Pixels with normalised
    distance <= ``1 - strength`` keep 1.0, and the value drops linearly to
    0.0 over the last ``strength`` of the range.

    A dtype-dependent epsilon (get_safe_epsilon) under the sqrt keeps it
    float16-safe.

    Args:
        h, w: Mask height and width
        strength: Width of the fade band in normalised distance units.
            ``strength <= 0`` means "no vignette" and returns all ones; the
            formula below would otherwise divide by zero (NaN/inf) or invert
            the vignette.
        device, dtype: Placement and precision of the result

    Returns:
        Tensor of shape [1, 1, h, w]

    NOTE: this is the v13 radial vignette; v1/exp had no equivalent.
    """
    if strength <= 0:
        return torch.ones(1, 1, h, w, device=device, dtype=dtype)

    eps_val = get_safe_epsilon(dtype)
    y, x = torch.meshgrid(
        torch.linspace(-1, 1, h, device=device, dtype=dtype),
        torch.linspace(-1, 1, w, device=device, dtype=dtype),
        indexing='ij'
    )
    # sqrt with protection
    dist = torch.sqrt(x * x + y * y + eps_val)
    # Normalise so the corners (max distance ~1.41) are ~1.0
    dist = dist / 1.4142
    # Invert: centre white (1), edges black (0)
    threshold = 1.0 - strength
    mask = 1.0 - torch.clamp((dist - threshold) / strength, 0, 1)
    # Add batch and channel dimensions -> [1, 1, h, w]
    if len(mask.shape) == 2:
        mask = mask.unsqueeze(0).unsqueeze(0)
    return mask
# =======================================================================
# 9. PARAMETER VALIDATION HELPERS
# =======================================================================
def validate_blend_params(params):
    """
    Read the blend settings from ``params`` and coerce them to their types.

    Keys read (default): blend_strength (0.5), blend_width (falsy -> None,
    meaning automatic width), blend_falloff ('smoothstep'), blend_sharpness
    (1.0), blend_fade_to_black (False), blend_fade_strength (0.1).

    An unsupported ``blend_falloff`` prints a warning and is replaced with
    'smoothstep', so callers never receive a curve the backend cannot
    honour. Supported: 'linear', 'smoothstep', 'cosine'.

    Returns:
        dict with keys strength, width, falloff, sharpness, fade_to_black,
        fade_strength (in that order)
    """
    supported = {'linear', 'smoothstep', 'cosine'}
    falloff = params.get('blend_falloff', 'smoothstep')
    if falloff not in supported:
        print(f"[validate_blend_params] Warning: unsupported blend_falloff '{falloff}' "
              f"β€” falling back to 'smoothstep'. Supported: {sorted(supported)}")
        falloff = 'smoothstep'

    strength = float(params.get('blend_strength', 0.5))
    raw_width = params.get('blend_width')
    width = int(raw_width) if raw_width else None

    return dict(
        strength=strength,
        width=width,
        falloff=falloff,
        sharpness=float(params.get('blend_sharpness', 1.0)),
        fade_to_black=bool(params.get('blend_fade_to_black', False)),
        fade_strength=float(params.get('blend_fade_strength', 0.1)),
    )
def validate_multires_params(params):
    """Extract and validate multi-resolution parameters from a dict.

    Strategy names are matched case-insensitively, the same way
    apply_multires_blend resolves them, and are returned lower-cased.
    Previously e.g. 'Linear' was rejected here and silently replaced with
    'cosine'. Unsupported values are normalised to 'cosine' with a warning.
    Supported: 'linear', 'cosine', 'exponential', 'sigmoid'.

    Keys read (default): multires_strategy ('cosine'), multires_start (0.0),
    multires_end (0.3), multires_sharpness (1.0).

    Returns:
        dict with keys strategy, transition_start, transition_end, sharpness
    """
    _SUPPORTED_STRATEGIES = {'linear', 'cosine', 'exponential', 'sigmoid'}
    strategy = params.get('multires_strategy', 'cosine')
    normalized = strategy.lower() if isinstance(strategy, str) else strategy
    if normalized not in _SUPPORTED_STRATEGIES:
        print(f"[validate_multires_params] Warning: unsupported multires_strategy '{strategy}' "
              f"β€” falling back to 'cosine'. Supported: {sorted(_SUPPORTED_STRATEGIES)}")
        normalized = 'cosine'
    return {
        'strategy': normalized,
        'transition_start': float(params.get('multires_start', 0.0)),
        'transition_end': float(params.get('multires_end', 0.3)),
        'sharpness': float(params.get('multires_sharpness', 1.0))
    }
# =======================================================================
# vUltimate - End of File
# CRITICAL FIXES APPLIED:
# ✅ Fix 1: Infinite recursion in get_safe_epsilon (v13 bug at line 51/65)
# ✅ Fix 2: Correct v11+ mask semantics (advanced first, mask=1.0 on edges)
# ✅ Fix 3: Cache key includes dtype (v11+ improvement)
# ✅ Fix 4: Safe reflect with size validation
# ✅ Fix 5: v13 fade_to_black logic (not v7 logic)
# ✅ Fix 6: v13 circular/vignette masks (not v1/exp)
# =======================================================================