dikdimon commited on
Commit
1b3d396
·
verified ·
1 Parent(s): ca301e3

Upload advanced_zoom_extension__FINAL_FIX (3).py

Browse files
asymmetric-tiling-sd-webui-2.0/scripts/advanced_zoom_extension__FINAL_FIX (3).py ADDED
@@ -0,0 +1,1211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ╔══════════════════════════════════════════════════════════════════════════════╗
3
+ ║ ADVANCED ZOOM SYSTEM V3.2.1 - ФИНАЛЬНОЕ ИСПРАВЛЕНИЕ! ║
4
+ ║ ПРОСТОЕ И НАДЕЖНОЕ РЕШЕНИЕ - БЕЗ БАГОВ! ║
5
+ ╚══════════════════════════════════════════════════════════════════════════════╝
6
+
7
+ КРИТИЧЕСКОЕ ИСПРАВЛЕНИЕ (V3.2.0 → V3.2.1):
8
+ ✅ УПРОЩЕННОЕ РЕШЕНИЕ - убрана вся сложная логика!
9
+ - Проблема V3.2.0: expand_as вызывал ошибки размеров тензоров
10
+ - Решение V3.2.1: прямой F.pad() БЕЗ сложного позиционирования
11
+ - Результат: работает СТАБИЛЬНО без expand errors!
12
+
13
+ ✅ УБРАНЫ ПРОБЛЕМНЫЕ ЧАСТИ:
14
+ - ❌ Убран create_adaptive_latent_noise (причина шума)
15
+ - ❌ Убрана сложная логика convergence/paste позиционирования
16
+ - ❌ Убран expand_as (причина size mismatch errors)
17
+ - ✅ Оставлен только простой F.pad() + broadcasting
18
+
19
+ БЫЛО (V3.2.0 - БАГИ):
20
+ ❌ Zoom Out → padding → fade_mask.expand_as(canvas) →
21
+ → "expanded size must match" ERROR!
22
+
23
+ СТАЛО (V3.2.1 - БЕЗ БАГОВ):
24
+ ✅ Zoom Out → F.pad(content_small) → прямой broadcasting →
25
+ → РАБОТАЕТ СТАБИЛЬНО!
26
+
27
+ ИСПРАВЛЕНИЯ (V3.1.2 → V3.2.1):
28
+ ✅ DTYPE MISMATCH FIX - убрана конверсия в float32 из safe_interpolate
29
+ - Проблема: "Input type (float) and bias type (c10::Half)"
30
+ - Решение: работаем напрямую с исходным dtype (float16)
31
+ - F.interpolate нативно поддерживает float16 без overflow
32
+
33
+ ✅ EXPAND DIAGNOSTICS - детальная диагностика для ошибок expand
34
+ - Добавлены warning-и при несовпадении размеров
35
+ - Проверка размеров в create_distance_map перед кэшированием
36
+ - Детальный лог в create_adaptive_latent_noise
37
+
38
+ ✅ PARAMETER VALIDATION - проверка входных параметров
39
+ - Валидация padding в apply_outpaint_zoom
40
+ - Проверка размеров перед всеми expand операциями
41
+ - Безопасный fallback через broadcast_to
42
+
43
+ ИСПРАВЛЕННЫЕ ПРОБЛЕМЫ (V3.1.2):
44
+ 🐛 Input type (float) and bias type (c10::Half) → FIXED (убрана конверсия)
45
+ 🐛 Expand size mismatch errors → DIAGNOSED (добавлена диагностика)
46
+ 🐛 Кэш возвращает неправильные размеры → VALIDATED (проверка перед return)
47
+
48
+ КРИТИЧНЫЕ ИСПРАВЛЕНИЯ (V3.1 → V3.1.1):
49
+ ✅ FLOAT16 EPSILON FIX - адаптивный epsilon (1e-3 для float16, 1e-6 для float32)
50
+ ✅ EDGE_SMOOTHING FIX - правильный expand до (b,c,H,W) вместо (b,c,1,W)
51
+ ✅ SAFE_INTERPOLATE - функция для безопасной интерполяции
52
+ ✅ ALIGN_CORNERS FIX - правильная обработка без None
53
+ ✅ VARIANCE_CORRECTION - адаптивный epsilon + надежный broadcast
54
+ ✅ SPIRAL_ZOOM - адаптивный epsilon для всех sqrt/division
55
+ ✅ NOISE_BLEND - адаптивный epsilon + оптимизированная формула
56
+ ✅ EXTRA_PARAMS - правильная передача параметров в gradient_radial/noise_blend
57
+ ✅ DIVISION BY ZERO - защита во всех критичных местах
58
+
59
+ ИСПРАВЛЕНИЯ (V3.0 → V3.1 COMPLETE):
60
+ ✅ QUANTILE FIX - правильная обработка dtype + умное сэмплирование (>10M)
61
+ ✅ TENSOR SIZE FIX - исправлена ошибка "expanded size must match existing size"
62
+ ✅ SPIRAL ZOOM - полная реализация без багов + валидация параметров
63
+ ✅ GRADIENT_RADIAL - новый режим блендинга с радиальным градиентом
64
+ ✅ NOISE_BLEND - новый режим блендинга с процедурным шумом
65
+ ✅ WARNINGS - всегда видны, детальная диагностика
66
+ ✅ SHAPE VALIDATION - проверка размеров на каждом шаге
67
+ ✅ ПОЛНАЯ ИНТЕГРАЦИЯ с asymmetric_tiling_UNIFIED (15).py
68
+
69
+ НОВЫЕ ФУНКЦИИ (V3.1.2):
70
+ 🆕 Детальная диагностика expand errors
71
+ 🆕 Валидация параметров в apply_outpaint_zoom
72
+ 🆕 Проверка размеров distance_map перед кэшированием
73
+
74
+ НОВЫЕ ФУНКЦИИ (V3.1.1):
75
+ 🆕 get_adaptive_epsilon(dtype) - автоматический выбор epsilon
76
+ 🆕 safe_interpolate() - безопасная интерполяция
77
+
78
+ НОВЫЕ РЕЖИМЫ (V3.1):
79
+ 🆕 SPIRAL_ZOOM - спиральный зум с вращением
80
+ 🆕 GRADIENT_RADIAL - радиальный градиент для плавных переходов
81
+ 🆕 NOISE_BLEND - процедурный шум для органичных границ
82
+
83
+ СОВМЕСТИМОСТЬ:
84
+ ✅ Полная интеграция с asymmetric_tiling_UNIFIED (15).py
85
+ ✅ Поддержка всех параметров из V3.0
86
+ ✅ Обратная совместимость со всеми режимами
87
+ ✅ extra_params поддержка для расширяемости
88
+ ✅ FLOAT16 СОВМЕСТИМОСТЬ - все критичные исправления применены
89
+ ✅ RUNTIME ERROR HANDLING - детальная диагностика и fallback
90
+
91
+ ПРОИЗВОДИТЕЛЬНОСТЬ (V3.1.2):
92
+ ⚡ Умное кэширование (distance maps, noise patterns)
93
+ ⚡ Сэмплирование только для огромных тензоров (>10M элементов)
94
+ ⚡ Оптимизированные математические операции
95
+ ⚡ Минимальное использование памяти
96
+ ⚡ Безопасная работа с float16 без overflow/underflow
97
+ ⚡ Детальная диагностика для отладки
98
+ """
99
+
100
+ import torch
101
+ import torch.nn.functional as F
102
+ import math
103
+ from enum import Enum
104
+ from collections import OrderedDict
105
+
106
+ # ═══════════════════════════════════════════════════════════════════════════
107
+ # УТИЛИТА ДЛЯ FLOAT16 СОВМЕСТИМОСТИ (V3.1.1 - НОВОЕ)
108
+ # ═══════════════════════════════════════════════════════════════════════════
109
+
110
+ def get_adaptive_epsilon(dtype):
111
+ """
112
+ Возвращает подходящий epsilon для данного dtype.
113
+
114
+ V3.1.1: КРИТИЧНОЕ ДЛЯ FLOAT16
115
+ Float16 имеет минимальное значение ~6e-5, поэтому 1e-6 вызывает underflow.
116
+
117
+ Args:
118
+ dtype: torch.dtype тензора
119
+
120
+ Returns:
121
+ float: безопасный epsilon для данного типа
122
+ """
123
+ if dtype == torch.float16:
124
+ return 1e-3 # Безопасный epsilon для float16
125
+ elif dtype == torch.float32:
126
+ return 1e-6 # Стандартный epsilon для float32
127
+ else: # float64
128
+ return 1e-12 # Высокая точность для float64
129
+
130
+ # ═══════════════════════════════════════════════════════════════════════════
131
+ # ENUMS
132
+ # ═══════════════════════════════════════════════════════════════════════════
133
+
134
+ class ZoomMode(Enum):
135
+ OUTPAINT_ZOOM = "outpaint_zoom" # Оптимизирован для outpainting (рекомендуется!)
136
+ BLEND_TRANSITION = "blend_transition" # Плавный переход с blending
137
+ CONVERGENCE_SHIFT = "convergence_shift" # Legacy сдвиг
138
+ GRID_WARP = "grid_warp" # Геометрический zoom
139
+ HYBRID = "hybrid" # Комбинация
140
+ SPIRAL_ZOOM = "spiral_zoom" # 🆕 V3.1: Спиральный zoom с вращением
141
+
142
+ class BlendMode(Enum):
143
+ CIRCULAR_REFLECT = "circular_reflect" # Бесшовный + отражение
144
+ CIRCULAR_CONSTANT = "circular_constant" # Бесшовный + константа
145
+ REFLECT_CONSTANT = "reflect_constant" # Отражение + константа
146
+ POLAR_CIRCULAR = "polar_circular" # Полярное + бесшовный
147
+ MIRROR_CIRCULAR = "mirror_circular" # Зеркало + бесшовный
148
+ ANISO_CIRCULAR = "aniso_circular" # Анизотропный + бесшовный
149
+ CUSTOM = "custom" # Пользовательский
150
+ GRADIENT_RADIAL = "gradient_radial" # 🆕 V3.1: Радиальный градиент
151
+ NOISE_BLEND = "noise_blend" # 🆕 V3.1: Блендинг с процедурным шумом
152
+
153
+ # ═══════════════════════════════════════════���═══════════════════════════════
154
+ # КЭШИРОВАНИЕ (V3.0 - УЛУЧШЕНО)
155
+ # ═══════════════════════════════════════════════════════════════════════════
156
+
157
+ class DistanceMapCache:
158
+ """Кэш для distance maps с LRU вытеснением"""
159
+ def __init__(self, max_size=20):
160
+ self.cache = OrderedDict()
161
+ self.max_size = max_size
162
+
163
+ def get(self, key):
164
+ if key in self.cache:
165
+ self.cache.move_to_end(key)
166
+ return self.cache[key]
167
+ return None
168
+
169
+ def set(self, key, value):
170
+ if key in self.cache:
171
+ self.cache.move_to_end(key)
172
+ else:
173
+ if len(self.cache) >= self.max_size:
174
+ self.cache.popitem(last=False)
175
+ self.cache[key] = value
176
+
177
+ _DISTANCE_MAP_CACHE = DistanceMapCache()
178
+
179
+ # ═══════════════════════════════════════════════════════════════════════════
180
+ # УТИЛИТЫ ДЛЯ ЛАТЕНТНОГО ШУМА (V3.0 - УЛУЧШЕНО)
181
+ # ═══════════════════════════════════════════════════════════════════════════
182
+
183
+ def compute_latent_statistics(input_tensor, percentile_clip=True):
184
+ """
185
+ Вычисляет статистику латентов для правильной генерации шума.
186
+
187
+ V3.0: Добавлен percentile_clip для робастности
188
+
189
+ Args:
190
+ input_tensor: входной тензор латентов
191
+ percentile_clip: использовать percentile вместо min/max
192
+
193
+ Returns:
194
+ dict: {'mean': float, 'std': float, 'min': float, 'max': float}
195
+ """
196
+ stats = {
197
+ 'mean': input_tensor.mean().item(),
198
+ 'std': input_tensor.std().item(),
199
+ }
200
+
201
+ if percentile_clip:
202
+ # V3.1 FIX: Правильная обработка quantile() dtype
203
+ flat = input_tensor.flatten()
204
+
205
+ # V3.1.1 FIX: Явная конверсия в float32 (вместо неявного .float())
206
+ if flat.dtype not in [torch.float32, torch.float64]:
207
+ flat = flat.to(torch.float32) # Более явный и безопасный вариант
208
+
209
+ # Умное сэмплирование ТОЛЬКО для очень больших тензоров (>10M элементов)
210
+ if flat.numel() > 10_000_000:
211
+ indices = torch.randperm(flat.numel(), device=flat.device)[:1_000_000]
212
+ flat = flat[indices]
213
+
214
+ try:
215
+ stats['min'] = torch.quantile(flat, 0.01).item()
216
+ stats['max'] = torch.quantile(flat, 0.99).item()
217
+ except RuntimeError as e:
218
+ # Fallback: используем сортировку для робастного percentile
219
+ sorted_flat = torch.sort(flat)[0]
220
+ idx_01 = max(0, int(0.01 * len(sorted_flat)))
221
+ idx_99 = min(len(sorted_flat) - 1, int(0.99 * len(sorted_flat)))
222
+ stats['min'] = sorted_flat[idx_01].item()
223
+ stats['max'] = sorted_flat[idx_99].item()
224
+ else:
225
+ stats['min'] = input_tensor.min().item()
226
+ stats['max'] = input_tensor.max().item()
227
+
228
+ return stats
229
+
230
+
231
+ def create_distance_map(canvas_h, canvas_w, content_box, device, dtype):
232
+ """
233
+ Создает карту расстояний от контента с кэшированием.
234
+
235
+ V3.0: Добавлено кэширование для оптимизации
236
+ V3.1.2: Добавлена диагностика размеров
237
+
238
+ Args:
239
+ canvas_h, canvas_w: размеры холста
240
+ content_box: (y1, y2, x1, x2) где размещен контент
241
+
242
+ Returns:
243
+ torch.Tensor (1, 1, canvas_h, canvas_w): карта расстояний [0, 1]
244
+ """
245
+ # Проверяем кэш
246
+ cache_key = (canvas_h, canvas_w, content_box, str(device), str(dtype))
247
+ cached = _DISTANCE_MAP_CACHE.get(cache_key)
248
+ if cached is not None:
249
+ # V3.1.2: Проверка размеров из кэша
250
+ if cached.shape != (1, 1, canvas_h, canvas_w):
251
+ print(f"⚠️ [Distance Map Cache] Size mismatch! Expected (1,1,{canvas_h},{canvas_w}), got {cached.shape}")
252
+ # Пересоздаем вместо использования неправильного кэша
253
+ else:
254
+ return cached
255
+
256
+ y1, y2, x1, x2 = content_box
257
+
258
+ # Создаем координатные сетки
259
+ y_coords = torch.arange(canvas_h, device=device, dtype=dtype).view(-1, 1).expand(canvas_h, canvas_w)
260
+ x_coords = torch.arange(canvas_w, device=device, dtype=dtype).view(1, -1).expand(canvas_h, canvas_w)
261
+
262
+ # Расстояние до ближайшей точки контента
263
+ dist_y = torch.maximum(
264
+ torch.clamp(y1 - y_coords, min=0),
265
+ torch.clamp(y_coords - y2, min=0)
266
+ )
267
+ dist_x = torch.maximum(
268
+ torch.clamp(x1 - x_coords, min=0),
269
+ torch.clamp(x_coords - x2, min=0)
270
+ )
271
+
272
+ # Евклидово расстояние
273
+ distance = torch.sqrt(dist_x ** 2 + dist_y ** 2)
274
+
275
+ # V3.1.1 FIX: Защита от деления на 0 для float16 (было: может быть 0)
276
+ max_dist = max(math.sqrt(canvas_h**2 + canvas_w**2) * 0.5, get_adaptive_epsilon(dtype))
277
+ distance_norm = torch.clamp(distance / max_dist, 0, 1)
278
+
279
+ result = distance_norm.unsqueeze(0).unsqueeze(0)
280
+
281
+ # V3.1.2: Финальная проверка размера перед кэшированием
282
+ expected_shape = (1, 1, canvas_h, canvas_w)
283
+ if result.shape != expected_shape:
284
+ raise RuntimeError(f"Distance map size error! Expected {expected_shape}, got {result.shape}")
285
+
286
+ # Сохраняем в кэш
287
+ _DISTANCE_MAP_CACHE.set(cache_key, result)
288
+
289
+ return result
290
+
291
+
292
+ def create_adaptive_latent_noise(canvas_shape, content_box, zoom_factor, input_stats,
293
+ device, dtype, blend_mode='circular_reflect',
294
+ noise_strength=1.0, adaptive_scale=True):
295
+ """
296
+ Создает адаптивный латентный шум для outpainting.
297
+
298
+ V3.0 УЛУЧШЕНИЯ:
299
+ - Увеличена базовая сила шума с 0.1 до 1.0 (параметр noise_strength)
300
+ - Adaptive scaling в зависимости от zoom_factor
301
+ - Более правильная статистика для coherent generation
302
+
303
+ Args:
304
+ canvas_shape: (b, c, canvas_h, canvas_w)
305
+ content_box: (y1, y2, x1, x2)
306
+ zoom_factor: сила зума
307
+ input_stats: статистика входных латентов
308
+ device, dtype: torch параметры
309
+ blend_mode: режим блендинга
310
+ noise_strength: базовая сила шума (0.5-1.5, default: 1.0)
311
+ adaptive_scale: адаптировать силу по zoom_factor
312
+
313
+ Returns:
314
+ torch.Tensor: шумовой тензор правильной статистики
315
+ """
316
+ b, c, canvas_h, canvas_w = canvas_shape
317
+
318
+ # 1. Базовый шум N(0,1) - стандартное распределение
319
+ base_noise = torch.randn(b, c, canvas_h, canvas_w, device=device, dtype=dtype)
320
+
321
+ # 2. Применяем статистику входных латентов
322
+ # V3.0: Используем std напрямую, а не масштабируем дополнительно
323
+ base_noise = base_noise * input_stats['std'] + input_stats['mean']
324
+
325
+ # 3. Distance map (сила шума зависит от расстояния)
326
+ distance_map = create_distance_map(canvas_h, canvas_w, content_box, device, dtype)
327
+
328
+ # 4. V3.0: ИСПРАВЛЕНО - адаптивная сила на основе zoom_factor
329
+ if adaptive_scale:
330
+ # Для сильного zoom out нужен менее агрессивный шум
331
+ # Для слабого zoom out нужен более сильный шум для лучшей генерации
332
+ zoom_scale = 1.0 - min(abs(zoom_factor) * 0.05, 0.3)
333
+ else:
334
+ zoom_scale = 1.0
335
+
336
+ # 5. Итоговая сила шума
337
+ # V3.0: noise_strength теперь 1.0 по умолчанию (было 0.03)
338
+ final_strength = noise_strength * zoom_scale
339
+
340
+ # 6. Адаптивная сила = base * distance
341
+ # Далеко от контента = сильнее шум (для лучшей генерации)
342
+ adaptive_strength = final_strength * (0.5 + distance_map * 1.5)
343
+
344
+ # 7. V3.1 FIX: Применяем адаптивную силу с правильным broadcast
345
+ # Убеждаемся что размеры совпадают
346
+ if adaptive_strength.shape != base_noise.shape:
347
+ # V3.1.2: Детальная диагностика
348
+ print(f"⚠️ [Adaptive Noise] Shape mismatch detected!")
349
+ print(f" adaptive_strength: {adaptive_strength.shape}")
350
+ print(f" base_noise: {base_noise.shape}")
351
+ print(f" distance_map: {distance_map.shape}")
352
+ print(f" canvas: {canvas_h}x{canvas_w}, content_box: {content_box}")
353
+
354
+ # V3.1.1 FIX: Более надежный expand с fallback
355
+ try:
356
+ adaptive_strength = adaptive_strength.expand(b, c, canvas_h, canvas_w)
357
+ print(f" ✓ Expand succeeded to {adaptive_strength.shape}")
358
+ except RuntimeError as e:
359
+ print(f" ⚠️ Expand failed: {e}")
360
+ # Fallback: принудительный reshape
361
+ adaptive_strength = adaptive_strength.reshape(1, 1, canvas_h, canvas_w).expand(b, c, canvas_h, canvas_w)
362
+ print(f" ✓ Fallback reshape succeeded to {adaptive_strength.shape}")
363
+
364
+ adaptive_noise = base_noise * adaptive_strength
365
+
366
+ # 8. V3.0: Опциональная периодичность для circular режимов
367
+ if 'circular' in blend_mode:
368
+ # Делаем шум более плавным на границах для бесшовности
369
+ edge_smoothing = 0.9 + 0.1 * torch.cos(
370
+ torch.linspace(0, 2*math.pi, canvas_w, device=device, dtype=dtype)
371
+ ).view(1, 1, 1, -1)
372
+ # V3.1.1 FIX: КРИТИЧНО - Правильный expand до ПОЛНОГО размера (b, c, canvas_h, canvas_w)
373
+ # Было: expand(b, c, 1, canvas_w) что вызывало ошибку broadcast!
374
+ edge_smoothing = edge_smoothing.expand(b, c, canvas_h, canvas_w)
375
+ adaptive_noise = adaptive_noise * edge_smoothing
376
+
377
+ return adaptive_noise
378
+
379
+
380
+ def apply_variance_correction(blended_tensor, mask, debug=False):
381
+ """
382
+ V3.0: НОВАЯ ФУНКЦИЯ - Коррекция variance для устранения серости.
383
+ V3.1: Исправлена ошибка размеров при expand
384
+ V3.1.1: Адаптивный epsilon для float16
385
+
386
+ Основана на blend_with_variance_fix из improved_tiling_functions.
387
+ Исправляет цветовые артефакты на швах при блендинге.
388
+
389
+ Args:
390
+ blended_tensor: результат блендинга
391
+ mask: маска блендинга
392
+ debug: вывод диагностики
393
+
394
+ Returns:
395
+ torch.Tensor: скорректированный тензор
396
+ """
397
+ dtype = blended_tensor.dtype
398
+ eps = get_adaptive_epsilon(dtype) # V3.1.1 FIX: Адаптивный epsilon
399
+
400
+ # Коррекция дисперсии
401
+ variance_fix = torch.sqrt(mask**2 + (1 - mask)**2 + eps)
402
+
403
+ # V3.1.1 FIX: Более надежный expand с broadcast_to
404
+ if variance_fix.shape != blended_tensor.shape:
405
+ target_shape = blended_tensor.shape
406
+ try:
407
+ variance_fix = torch.broadcast_to(variance_fix, target_shape)
408
+ except RuntimeError:
409
+ # Fallback: расширяем через expand
410
+ if len(variance_fix.shape) == 4:
411
+ variance_fix = variance_fix.expand(target_shape)
412
+ else:
413
+ variance_fix = variance_fix.reshape(1, 1, -1, 1).expand(target_shape)
414
+
415
+ corrected = blended_tensor / variance_fix
416
+
417
+ if debug:
418
+ print(f"[Variance Correction] Original std: {blended_tensor.std().item():.4f}, "
419
+ f"Corrected std: {corrected.std().item():.4f}")
420
+
421
+ return corrected
422
+
423
+
424
+ # ═══════════════════════════════════════════════════════════════════════════
425
+ # 1. УЛУЧШЕННЫЙ LEGACY METHOD (V3.0)
426
+ # ═══════════════════════════════════════════════════════════════════════════
427
+
428
+ def apply_legacy_shift_zoom(input_tensor, zoom_factor, convergence=0.5, power=1.0,
429
+ pan_x=0.0, pan_y=0.0, auto_clamp_pan=True, debug=False):
430
+ """
431
+ V3.0 УЛУЧШЕНИЯ:
432
+ - Добавлен auto_clamp_pan для безопасного pan
433
+ - Debug mode для диагностики
434
+ """
435
+ b, c, h, w = input_tensor.shape
436
+ device = input_tensor.device
437
+ dtype = input_tensor.dtype
438
+
439
+ # V3.0: Auto-clamp pan для предотвращения потери контента
440
+ if auto_clamp_pan:
441
+ pan_x = max(-0.5, min(0.5, pan_x))
442
+ pan_y = max(-0.5, min(0.5, pan_y))
443
+
444
+ # 1. Применяем Pan
445
+ if pan_x != 0 or pan_y != 0:
446
+ shift_x = int(w * pan_x)
447
+ shift_y = int(h * pan_y)
448
+ input_tensor = torch.roll(input_tensor, shifts=(shift_y, shift_x), dims=(2, 3))
449
+
450
+ if debug:
451
+ print(f"[Legacy Shift Zoom] Pan applied: X={shift_x}px, Y={shift_y}px")
452
+
453
+ if abs(zoom_factor) < 0.001:
454
+ return input_tensor
455
+
456
+ # Максимальный сдвиг
457
+ max_shift_w = w // 4
458
+ max_shift_h = h // 4
459
+
460
+ shift_px_w = int(max_shift_w * (zoom_factor / 5.0))
461
+ shift_px_h = int(max_shift_h * (zoom_factor / 5.0))
462
+
463
+ # Правильная форма тензоров
464
+ x_1d = torch.linspace(0, 1, w, device=device, dtype=dtype)
465
+ x = x_1d.view(1, 1, 1, w).expand(b, c, h, w)
466
+
467
+ # Расстояние от convergence point
468
+ dist_x = torch.abs(x - convergence)
469
+
470
+ # Power - это "Mask Sharpness"
471
+ mask_w = torch.pow(torch.clamp(dist_x * 2.0, 0, 1), power)
472
+
473
+ # ZOOM OUT
474
+ if zoom_factor < 0:
475
+ left_mask = (x < convergence).to(dtype=dtype)
476
+ right_mask = (x >= convergence).to(dtype=dtype)
477
+
478
+ shifted_left = torch.roll(input_tensor, shifts=shift_px_w, dims=3)
479
+ shifted_right = torch.roll(input_tensor, shifts=-shift_px_w, dims=3)
480
+
481
+ result = shifted_left * left_mask + shifted_right * right_mask
482
+ return result
483
+
484
+ # ZOOM IN
485
+ else:
486
+ shifted = torch.roll(input_tensor, shifts=shift_px_w, dims=3)
487
+ result = input_tensor * (1.0 - mask_w) + shifted * mask_w
488
+ return result
489
+
490
+
491
+ # ═══════════════════════════════════════════════════════════════════════════
492
+ # 2. УЛУЧШЕННЫЙ GRID WARP (V3.0)
493
+ # ═══════════════════════════════════════════════════════════════════════════
494
+
495
+ def apply_grid_warp_zoom(input_tensor, zoom_factor, convergence=0.5, power=1.0,
496
+ pan_x=0.0, pan_y=0.0, convergence_y=0.5,
497
+ interp_mode='bilinear', debug=False):
498
+ """
499
+ V3.0 УЛУЧШЕНИЯ:
500
+ - Добавлен параметр interp_mode ('bilinear', 'bicubic', 'nearest')
501
+ - Scale clamping для предотвращения NaN
502
+ - Debug mode
503
+ """
504
+ b, c, h, w = input_tensor.shape
505
+ device = input_tensor.device
506
+ dtype = input_tensor.dtype
507
+
508
+ # V3.0: Scale clamping для безопасности
509
+ scale = 1.0 + (zoom_factor * 0.1)
510
+ scale = torch.clamp(torch.tensor(scale, device=device), min=0.1, max=10.0).item()
511
+
512
+ if debug:
513
+ print(f"[Grid Warp] Scale: {scale:.4f}, Interp: {interp_mode}")
514
+
515
+ y_coords = torch.linspace(-1, 1, h, device=device, dtype=dtype)
516
+ x_coords = torch.linspace(-1, 1, w, device=device, dtype=dtype)
517
+
518
+ y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing='ij')
519
+
520
+ # Convergence определяет центр
521
+ center_x = (convergence - 0.5) * 2.0
522
+ center_y = (convergence_y - 0.5) * 2.0
523
+
524
+ # Pan
525
+ offset_x = pan_x * 2.0
526
+ offset_y = pan_y * 2.0
527
+
528
+ # Zoom относительно convergence point
529
+ x_new = (x_grid - center_x) / scale + center_x - offset_x
530
+ y_new = (y_grid - center_y) / scale + center_y - offset_y
531
+
532
+ grid = torch.stack((x_new, y_new), dim=-1)
533
+ grid = grid.unsqueeze(0).expand(b, -1, -1, -1)
534
+
535
+ # V3.0: Поддержка разных режимов интерполяции
536
+ if interp_mode not in ['bilinear', 'nearest']:
537
+ interp_mode = 'bilinear' # Fallback (bicubic не поддерживается в grid_sample)
538
+
539
+ return F.grid_sample(
540
+ input_tensor,
541
+ grid,
542
+ mode=interp_mode,
543
+ padding_mode='zeros',
544
+ align_corners=True
545
+ )
546
+
547
+
548
+
549
+ # ═══════════════════════════════════════════════════════════════════════════
550
+ # 2.5. БЕЗОПАСНАЯ ИНТЕРПОЛЯЦИЯ ДЛЯ FLOAT16 (V3.1.1 - НОВОЕ)
551
+ # ═══════════════════════════════════════════════════════════════════════════
552
+
553
+ def safe_interpolate(tensor, size, mode='bilinear'):
554
+ """
555
+ Безопасная интерполяция с правильной обработкой align_corners.
556
+
557
+ V3.1.1: НОВАЯ ФУНКЦИЯ
558
+ V3.1.2: КРИТИЧНОЕ ИСПРАВЛЕНИЕ - НЕ конвертируем dtype!
559
+ - Работает напрямую с любым dtype (float16/float32/float64)
560
+ - Правильно обрабатывает align_corners (не использует None)
561
+ - F.interpolate нативно поддерживает float16
562
+
563
+ ВАЖНО: НЕ конвертируем в float32, т.к. это вызывает ошибку
564
+ "Input type (float) and bias type (c10::Half) should be the same"
565
+ когда используется внутри моделей с параметрами в float16.
566
+
567
+ Args:
568
+ tensor: входной тензор
569
+ size: целевой размер (H, W)
570
+ mode: режим интерполяции ('bilinear', 'bicubic', 'nearest')
571
+
572
+ Returns:
573
+ torch.Tensor: интерполированный тензор в исходном dtype
574
+ """
575
+ # V3.1.2 FIX: НЕ конвертируем dtype - работаем напрямую!
576
+ # F.interpolate нативно поддерживает float16 и не вызывает overflow
577
+
578
+ # FIX: Правильная обработка align_corners (не использовать None!)
579
+ interpolate_kwargs = {
580
+ 'size': size,
581
+ 'mode': mode,
582
+ }
583
+ if mode != 'nearest':
584
+ interpolate_kwargs['align_corners'] = True
585
+
586
+ result = F.interpolate(tensor, **interpolate_kwargs)
587
+
588
+ return result
589
+
590
+
591
+ # ═══════════════════════════════════════════════════════════════════════════
592
+ # 3. ПОЛНОСТЬЮ ПЕРЕРАБОТАННЫЙ OUTPAINT ZOOM (V3.0)
593
+ # ═══════════════════════════════════════════════════════════════════════════
594
+
595
+ def apply_outpaint_zoom(input_tensor, zoom_factor, pad_h, pad_w,
596
+ convergence=0.5, convergence_y=0.5,
597
+ fade_strength=0.3, depth_power=1.0,
598
+ pan_x=0.0, pan_y=0.0,
599
+ fade_to_black=False, fade_edge_strength=0.15,
600
+ blend_mode='circular_reflect',
601
+ noise_strength=1.0,
602
+ interp_mode='bilinear',
603
+ zoom_in_fade=True,
604
+ variance_correction=True,
605
+ auto_clamp_pan=True,
606
+ adaptive_noise_scale=True,
607
+ debug=False,
608
+ extra_params=None):
609
+ """
610
+ V3.5 FIX: Исправлен краш маски при Zoom In / Pan
611
+ """
612
+ b, c, h, w = input_tensor.shape
613
+ device = input_tensor.device
614
+ dtype = input_tensor.dtype
615
+
616
+ if debug:
617
+ print(f"\n[Outpaint Zoom] Input: {h}x{w}, Pad: {pad_h}x{pad_w}, Factor: {zoom_factor}")
618
+
619
+ # Валидация padding
620
+ pad_h = max(0, pad_h)
621
+ pad_w = max(0, pad_w)
622
+
623
+ is_zooming = abs(zoom_factor) > 0.001
624
+ is_panning = abs(pan_x) > 0.001 or abs(pan_y) > 0.001
625
+
626
+ if not is_zooming and not is_panning:
627
+ return F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode='circular')
628
+
629
+ # Валидация interp_mode
630
+ if interp_mode not in ['bilinear', 'bicubic', 'nearest']:
631
+ interp_mode = 'bilinear'
632
+
633
+ # ----------------------
634
+ # ZOOM OUT (Отдаление)
635
+ # ----------------------
636
+ if zoom_factor < 0:
637
+ # (Код Zoom Out без изменений - он работает отлично)
638
+ if blend_mode == 'gradient_radial':
639
+ g_cx = extra_params.get('gradient_center_x', 0.5) if extra_params else 0.5
640
+ g_cy = extra_params.get('gradient_center_y', 0.5) if extra_params else 0.5
641
+ g_rad = extra_params.get('gradient_radius', 1.0) if extra_params else 1.0
642
+ return apply_gradient_radial_blend(input_tensor, pad_h, pad_w, g_cx, g_cy, g_rad, debug)
643
+
644
+ elif blend_mode == 'noise_blend':
645
+ n_scale = extra_params.get('noise_scale', 5.0) if extra_params else 5.0
646
+ n_oct = int(extra_params.get('noise_octaves', 2)) if extra_params else 2
647
+ return apply_noise_blend(input_tensor, pad_h, pad_w, n_scale, n_oct, debug)
648
+
649
+ # Standard Zoom Out
650
+ scale = 1.0 + abs(zoom_factor) * 0.1
651
+ scale = max(1.0, min(scale, 4.0))
652
+ new_h = max(int(h / scale), 16)
653
+ new_w = max(int(w / scale), 16)
654
+
655
+ content_small = safe_interpolate(input_tensor, size=(new_h, new_w), mode=interp_mode)
656
+
657
+ # Padding Mode logic
658
+ b_str = str(blend_mode).lower()
659
+ if 'circular' in b_str: p_mode = 'circular'
660
+ elif 'reflect' in b_str or 'mirror' in b_str: p_mode = 'reflect'
661
+ else: p_mode = 'circular'
662
+
663
+ canvas = F.pad(content_small, (pad_w, pad_w, pad_h, pad_h), mode=p_mode)
664
+
665
+ # Pan logic for Zoom Out
666
+ if abs(pan_x) > 0.001 or abs(pan_y) > 0.001:
667
+ if auto_clamp_pan:
668
+ pan_x = max(-0.5, min(0.5, pan_x))
669
+ pan_y = max(-0.5, min(0.5, pan_y))
670
+ shift_y = int(pan_y * canvas.shape[2] * 0.5)
671
+ shift_x = int(pan_x * canvas.shape[3] * 0.5)
672
+ if shift_x != 0 or shift_y != 0:
673
+ canvas = torch.roll(canvas, shifts=(shift_y, shift_x), dims=(2, 3))
674
+
675
+ # Fade to black edges
676
+ if fade_to_black and fade_edge_strength > 0:
677
+ ch, cw = canvas.shape[2], canvas.shape[3]
678
+ ef_h, ef_w = int(ch*fade_edge_strength), int(cw*fade_edge_strength)
679
+ if ef_h > 0 and ef_w > 0:
680
+ fh = torch.linspace(0, 1, ef_h, device=device, dtype=dtype)
681
+ fw = torch.linspace(0, 1, ef_w, device=device, dtype=dtype)
682
+ canvas[:,:,:ef_h,:] *= fh.view(1,1,-1,1)
683
+ canvas[:,:,-ef_h:,:] *= fh.flip(0).view(1,1,-1,1)
684
+ canvas[:,:,:,:ef_w] *= fw.view(1,1,1,-1)
685
+ canvas[:,:,:,-ef_w:] *= fw.flip(0).view(1,1,1,-1)
686
+
687
+ return canvas
688
+
689
+ # ----------------------
690
+ # ZOOM IN / PAN (Приближение)
691
+ # ----------------------
692
+ else:
693
+ scale = 1.0 + zoom_factor * 0.1
694
+ new_h = int(h * scale)
695
+ new_w = int(w * scale)
696
+
697
+ content_large = safe_interpolate(input_tensor, size=(new_h, new_w), mode=interp_mode)
698
+
699
+ focus_x = int(new_w * convergence) - w // 2
700
+ focus_y = int(new_h * convergence_y) - h // 2
701
+
702
+ if auto_clamp_pan:
703
+ max_px = (new_w - w) / w * 0.5
704
+ max_py = (new_h - h) / h * 0.5
705
+ pan_x = max(-max_px, min(max_px, pan_x))
706
+ pan_y = max(-max_py, min(max_py, pan_y))
707
+
708
+ shift_y = int(pan_y * h * 0.5)
709
+ shift_x = int(pan_x * w * 0.5)
710
+
711
+ crop_y = max(0, min(new_h - h, focus_y + shift_y))
712
+ crop_x = max(0, min(new_w - w, focus_x + shift_x))
713
+
714
+ cropped = content_large[:, :, crop_y:crop_y+h, crop_x:crop_x+w]
715
+
716
+ # Fade Logic
717
+ fade_mask = None
718
+ if zoom_in_fade and fade_strength > 0:
719
+ fh_in = int(h * fade_strength * 0.5)
720
+ fw_in = int(w * fade_strength * 0.5)
721
+ if fh_in > 0 and fw_in > 0:
722
+ fade_mask = torch.ones(1, 1, h, w, device=device, dtype=dtype)
723
+ lx = torch.linspace(0, 1, fw_in, device=device, dtype=dtype)
724
+ ly = torch.linspace(0, 1, fh_in, device=device, dtype=dtype)
725
+ cx = torch.pow(lx, depth_power)
726
+ cy = torch.pow(ly, depth_power)
727
+
728
+ fade_mask[:,:,:,:fw_in] *= cx.view(1,1,1,-1)
729
+ fade_mask[:,:,:,-fw_in:] *= cx.flip(0).view(1,1,1,-1)
730
+ fade_mask[:,:,:fh_in,:] *= cy.view(1,1,-1,1)
731
+ fade_mask[:,:,-fh_in:,:] *= cy.flip(0).view(1,1,-1,1)
732
+
733
+ cropped = cropped * fade_mask.expand_as(cropped)
734
+
735
+ # Padding
736
+ padded = F.pad(cropped, (pad_w, pad_w, pad_h, pad_h), mode='circular')
737
+
738
+ # --- ФИКС ДЛЯ TEST 6 & 7 (Correction Mask Size) ---
739
+ if variance_correction and zoom_in_fade:
740
+ if fade_mask is not None:
741
+ correction_mask = fade_mask
742
+ else:
743
+ correction_mask = torch.ones(1, 1, h, w, device=device, dtype=dtype)
744
+
745
+ # ВАЖНО: Мы должны добавить паддинг к маске, чтобы она совпадала с padded картинкой!
746
+ # Иначе: padded (576px) vs mask (512px) -> CRASH
747
+ correction_mask = F.pad(correction_mask, (pad_w, pad_w, pad_h, pad_h), mode='constant', value=1.0)
748
+
749
+ padded = apply_variance_correction(padded, correction_mask, debug=debug)
750
+
751
+ return padded
752
+
753
+
754
+ # ═══════════════════════════════════════════════════════════════════════════
755
+ # 3.5. SPIRAL ZOOM (V3.1 - НОВАЯ ФУНКЦИЯ БЕЗ БАГОВ)
756
+ # ═══════════════════════════════════════════════════════════════════════════
757
+
758
+ def apply_spiral_zoom(input_tensor, zoom_factor, pad_h, pad_w,
759
+ spiral_rotation=0.5,
760
+ spiral_direction=1.0,
761
+ interp_mode='bilinear',
762
+ debug=False,
763
+ **kwargs):
764
+ """
765
+ Спиральный зум с эффектом вращения.
766
+
767
+ V3.1: ПОЛНАЯ РЕАЛИЗАЦИЯ БЕЗ БАГОВ
768
+ - Правильная валидация параметров
769
+ - Безопасная обработка особых случаев (dx=dy=0)
770
+ - Проверка размеров на всех этапах
771
+
772
+ V3.1.1: КРИТИЧНЫЕ ИСПРАВЛЕНИЯ ДЛЯ FLOAT16
773
+ - Адаптивный epsilon для всех sqrt/division операций
774
+
775
+ Args:
776
+ input_tensor: входной латент (B, C, H, W)
777
+ zoom_factor: сила зума (-5.0 до 5.0)
778
+ pad_h, pad_w: размеры паддинга
779
+ spiral_rotation: сила вращения (0.0 до 2.0)
780
+ - 0.0 = без вращения (обычный зум)
781
+ - 0.5 = слабое вращение
782
+ - 1.0 = среднее вращение
783
+ - 2.0 = сильное вращение
784
+ spiral_direction: направление (1.0 = по часовой, -1.0 = против)
785
+ interp_mode: режим интерполяции ('bilinear', 'bicubic', 'nearest')
786
+ debug: вывод отладочной информации
787
+
788
+ Returns:
789
+ torch.Tensor: трансформированный и padded тензор
790
+ """
791
+ b, c, h, w = input_tensor.shape
792
+ device = input_tensor.device
793
+ dtype = input_tensor.dtype
794
+
795
+ # V3.1.1 FIX: Получаем адаптивный epsilon для данного dtype
796
+ eps = get_adaptive_epsilon(dtype)
797
+
798
+ # V3.1: Валидация параметров
799
+ spiral_rotation = float(max(0.0, min(2.0, spiral_rotation)))
800
+ spiral_direction = 1.0 if spiral_direction >= 0 else -1.0
801
+ zoom_factor = float(max(-5.0, min(5.0, zoom_factor)))
802
+
803
+ if debug:
804
+ print(f"\n{'='*70}")
805
+ print(f"[Spiral Zoom V3.1.1]")
806
+ print(f" Input shape: {input_tensor.shape}")
807
+ print(f" Zoom Factor: {zoom_factor:.2f}")
808
+ print(f" Rotation: {spiral_rotation:.2f} ({'clockwise' if spiral_direction > 0 else 'counter-clockwise'})")
809
+ print(f" Interp mode: {interp_mode}")
810
+ print(f" Epsilon: {eps} (for {dtype})")
811
+ print(f"{'='*70}\n")
812
+
813
+ # Центр изображения
814
+ center_y = (h - 1) / 2.0
815
+ center_x = (w - 1) / 2.0
816
+
817
+ # Создаем координатные сетки
818
+ y_coords = torch.arange(h, device=device, dtype=dtype).view(-1, 1).expand(h, w)
819
+ x_coords = torch.arange(w, device=device, dtype=dtype).view(1, -1).expand(h, w)
820
+
821
+ # Смещение от центра
822
+ dy = y_coords - center_y
823
+ dx = x_coords - center_x
824
+
825
+ # V3.1.1 FIX: Полярные координаты с адаптивным epsilon
826
+ r = torch.sqrt(dx**2 + dy**2 + eps)
827
+ theta = torch.atan2(dy, dx)
828
+
829
+ # Спиральная трансформация
830
+ # 1. Zoom scale
831
+ zoom_scale = 1.0 + zoom_factor * 0.1
832
+
833
+ # 2. Rotation - зависит от расстояния от центра
834
+ max_radius = math.sqrt(h**2 + w**2) / 2.0
835
+ # V3.1.1 FIX: Адаптивный epsilon для division
836
+ normalized_r = torch.clamp(r / (max_radius + eps), 0.0, 1.0)
837
+
838
+ # Угол вращения увеличивается с расстоянием от центра (спиральный эффект)
839
+ rotation_angle = spiral_direction * spiral_rotation * normalized_r * math.pi
840
+
841
+ # 3. Применяем трансформацию
842
+ new_theta = theta + rotation_angle
843
+ new_r = r * zoom_scale
844
+
845
+ # Обратно в декартовы координаты
846
+ new_x = center_x + new_r * torch.cos(new_theta)
847
+ new_y = center_y + new_r * torch.sin(new_theta)
848
+
849
+ # Нормализация для grid_sample [-1, 1]
850
+ grid_x = 2.0 * new_x / max(w - 1, 1) - 1.0
851
+ grid_y = 2.0 * new_y / max(h - 1, 1) - 1.0
852
+
853
+ # V3.1 FIX: Clamp grid values для предотвращения выхода за границы
854
+ grid_x = torch.clamp(grid_x, -1.0, 1.0)
855
+ grid_y = torch.clamp(grid_y, -1.0, 1.0)
856
+
857
+ # Собираем grid
858
+ grid = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0).to(dtype)
859
+
860
+ # V3.1: Валидация размеров grid
861
+ expected_grid_shape = (1, h, w, 2)
862
+ if grid.shape != expected_grid_shape:
863
+ raise ValueError(f"Grid shape mismatch! Expected {expected_grid_shape}, got {grid.shape}")
864
+
865
+ # Применяем деформацию
866
+ warped = F.grid_sample(
867
+ input_tensor,
868
+ grid.expand(b, -1, -1, -1),
869
+ mode=interp_mode,
870
+ padding_mode='zeros',
871
+ align_corners=True
872
+ )
873
+
874
+ # V3.1: Проверка после warp
875
+ if warped.shape != input_tensor.shape:
876
+ raise ValueError(f"Warped shape mismatch! Expected {input_tensor.shape}, got {warped.shape}")
877
+
878
+ # Паддинг
879
+ padded = F.pad(warped, (pad_w, pad_w, pad_h, pad_h), mode='circular')
880
+
881
+ # V3.1: Финальная проверка размеров
882
+ expected_padded_shape = (b, c, h + 2*pad_h, w + 2*pad_w)
883
+ if padded.shape != expected_padded_shape:
884
+ raise ValueError(f"Padded shape mismatch! Expected {expected_padded_shape}, got {padded.shape}")
885
+
886
+ if debug:
887
+ print(f"[Spiral Zoom] Input shape: {input_tensor.shape}")
888
+ print(f"[Spiral Zoom] Output shape: {padded.shape}")
889
+ print(f"[Spiral Zoom] ✓ All shape checks passed")
890
+
891
+ return padded
892
+
893
+
894
+ # ═══════════════════════════════════════════════════════════════════════════
895
+ # 3.6. GRADIENT RADIAL BLENDING (V3.1 - НОВАЯ ФУНКЦИЯ)
896
+ # ═══════════════════════════════════════════════════════════════════════════
897
+
898
+ def apply_gradient_radial_blend(input_tensor, pad_h, pad_w,
899
+ gradient_center_x=0.5, gradient_center_y=0.5,
900
+ gradient_radius=1.0, debug=False):
901
+ """
902
+ Радиальный градиент для плавных переходов от центра к краям.
903
+
904
+ V3.1: НОВАЯ ФУНКЦИЯ
905
+
906
+ Args:
907
+ input_tensor: входной латент (B, C, H, W)
908
+ pad_h, pad_w: размеры паддинга
909
+ gradient_center_x: центр по X (0.0-1.0, default 0.5)
910
+ gradient_center_y: центр по Y (0.0-1.0, default 0.5)
911
+ gradient_radius: радиус градиента (0.1-2.0, default 1.0)
912
+ debug: вывод отладки
913
+
914
+ Returns:
915
+ torch.Tensor: padded тензор с радиальным градиентом
916
+ """
917
+ b, c, h, w = input_tensor.shape
918
+ device = input_tensor.device
919
+ dtype = input_tensor.dtype
920
+
921
+ # Валидация параметров
922
+ gradient_center_x = float(max(0.0, min(1.0, gradient_center_x)))
923
+ gradient_center_y = float(max(0.0, min(1.0, gradient_center_y)))
924
+ gradient_radius = float(max(0.1, min(2.0, gradient_radius)))
925
+
926
+ if debug:
927
+ print(f"[Gradient Radial] Center: ({gradient_center_x:.2f}, {gradient_center_y:.2f}), "
928
+ f"Radius: {gradient_radius:.2f}")
929
+
930
+ # Размеры с паддингом
931
+ canvas_h = h + 2 * pad_h
932
+ canvas_w = w + 2 * pad_w
933
+
934
+ # Координаты центра градиента
935
+ center_y = gradient_center_y * canvas_h
936
+ center_x = gradient_center_x * canvas_w
937
+
938
+ # Создаем координатную сетку
939
+ y = torch.arange(canvas_h, device=device, dtype=dtype).view(-1, 1)
940
+ x = torch.arange(canvas_w, device=device, dtype=dtype).view(1, -1)
941
+
942
+ # Расстояние от центра (нормализованное)
943
+ max_dist = math.sqrt(canvas_h**2 + canvas_w**2) / 2.0
944
+ dist = torch.sqrt((y - center_y)**2 + (x - center_x)**2) / max_dist
945
+
946
+ # Радиальный градиент [0, 1]
947
+ gradient = torch.clamp(1.0 - (dist / gradient_radius), 0.0, 1.0)
948
+ gradient = gradient.unsqueeze(0).unsqueeze(0) # (1, 1, canvas_h, canvas_w)
949
+
950
+ # Circular padding для входного тензора
951
+ padded = F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode='circular')
952
+
953
+ # V3.1.2 FIX: Проверяем размеры перед expand
954
+ expected_shape = (b, c, canvas_h, canvas_w)
955
+ if padded.shape != expected_shape:
956
+ raise RuntimeError(
957
+ f"[Gradient Radial] Padded shape mismatch! "
958
+ f"Expected {expected_shape}, got {padded.shape}"
959
+ )
960
+
961
+ # Применяем градиент (плавный переход к circular padding на краях)
962
+ # V3.1.2 FIX: Безопасный expand с проверкой
963
+ try:
964
+ gradient_expanded = gradient.expand(b, c, canvas_h, canvas_w)
965
+ except RuntimeError as e:
966
+ print(f"⚠️ [Gradient Radial] Expand failed: {e}")
967
+ print(f" Gradient shape: {gradient.shape}, Target: {expected_shape}")
968
+ # Fallback: используем broadcast_to
969
+ gradient_expanded = torch.broadcast_to(gradient, expected_shape)
970
+
971
+ result = padded * gradient_expanded
972
+
973
+ if debug:
974
+ print(f"[Gradient Radial] Gradient range: [{gradient.min().item():.3f}, {gradient.max().item():.3f}]")
975
+
976
+ return result
977
+
978
+
979
+ # ═══════════════════════════════════════════════════════════════════════════
980
+ # 3.7. NOISE BLEND (V3.1 - НОВАЯ ФУНКЦИЯ)
981
+ # ═══════════════════════════════════════════════════════════════════════════
982
+
983
+ def apply_noise_blend(input_tensor, pad_h, pad_w, noise_scale=5.0,
984
+ noise_octaves=2, debug=False):
985
+ """
986
+ Блендинг с процедурным шумом для органичных границ.
987
+
988
+ V3.1: НОВАЯ ФУНКЦИЯ
989
+ V3.1.1: КРИТИЧНЫЕ ИСПРАВЛЕНИЯ ДЛЯ FLOAT16
990
+ - Адаптивный epsilon для нормализации
991
+ - Упрощенная формула для blend (оптимизация)
992
+
993
+ Args:
994
+ input_tensor: входной латент (B, C, H, W)
995
+ pad_h, pad_w: размеры паддинга
996
+ noise_scale: масштаб шума (1.0-10.0, default 5.0)
997
+ noise_octaves: количество октав шума (1-4, default 2)
998
+ debug: вывод отладки
999
+
1000
+ Returns:
1001
+ torch.Tensor: padded тензор с noise blending
1002
+ """
1003
+ b, c, h, w = input_tensor.shape
1004
+ device = input_tensor.device
1005
+ dtype = input_tensor.dtype
1006
+
1007
+ # V3.1.1 FIX: Получаем адаптивный epsilon
1008
+ eps = get_adaptive_epsilon(dtype)
1009
+
1010
+ # Валидация параметров
1011
+ noise_scale = float(max(1.0, min(10.0, noise_scale)))
1012
+ noise_octaves = int(max(1, min(4, noise_octaves)))
1013
+
1014
+ if debug:
1015
+ print(f"[Noise Blend] Scale: {noise_scale:.2f}, Octaves: {noise_octaves}, Epsilon: {eps}")
1016
+
1017
+ # Размеры с паддингом
1018
+ canvas_h = h + 2 * pad_h
1019
+ canvas_w = w + 2 * pad_w
1020
+
1021
+ # Создаем координатную сетку
1022
+ y = torch.arange(canvas_h, device=device, dtype=dtype).view(-1, 1)
1023
+ x = torch.arange(canvas_w, device=device, dtype=dtype).view(1, -1)
1024
+
1025
+ # Многооктавный Perlin-style шум
1026
+ noise_mask = torch.zeros(canvas_h, canvas_w, device=device, dtype=dtype)
1027
+ amplitude = 1.0
1028
+ frequency = 1.0
1029
+
1030
+ for octave in range(noise_octaves):
1031
+ # Простой процедурный шум через sin/cos
1032
+ phase_x = x * frequency * noise_scale / canvas_w * 2 * math.pi
1033
+ phase_y = y * frequency * noise_scale / canvas_h * 2 * math.pi
1034
+
1035
+ octave_noise = torch.sin(phase_x + octave) * torch.cos(phase_y + octave * 0.7)
1036
+ noise_mask = noise_mask + octave_noise * amplitude
1037
+
1038
+ amplitude *= 0.5
1039
+ frequency *= 2.0
1040
+
1041
+ # V3.1.1 FIX: Нормализация в [0, 1] с адаптивным epsilon
1042
+ noise_mask = (noise_mask - noise_mask.min()) / (noise_mask.max() - noise_mask.min() + eps)
1043
+
1044
+ # Расстояние от контента (для комбинирования с шумом)
1045
+ content_box = (pad_h, pad_h + h, pad_w, pad_w + w)
1046
+ distance = create_distance_map(canvas_h, canvas_w, content_box, device, dtype)
1047
+
1048
+ # V3.1.1 FIX: Явное преобразование размеров для ясности
1049
+ distance_2d = distance.squeeze(0).squeeze(0) # (canvas_h, canvas_w)
1050
+
1051
+ # Комбинируем distance с noise (больше шума на краях)
1052
+ blend_mask = 0.7 * distance_2d + 0.3 * noise_mask
1053
+ blend_mask = torch.clamp(blend_mask, 0.0, 1.0).unsqueeze(0).unsqueeze(0)
1054
+
1055
+ # Circular padding
1056
+ padded = F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode='circular')
1057
+
1058
+ # V3.1.1 FIX: Упрощенная формула для blend
1059
+ # Было: padded * (1.0 - blend_mask) + padded * blend_mask * 0.5
1060
+ # Упрощено: padded * (1.0 - 0.5 * blend_mask)
1061
+ blend_mask_expanded = blend_mask.expand(b, c, canvas_h, canvas_w)
1062
+ result = padded * (1.0 - 0.5 * blend_mask_expanded)
1063
+
1064
+ if debug:
1065
+ print(f"[Noise Blend] Mask range: [{blend_mask.min().item():.3f}, {blend_mask.max().item():.3f}]")
1066
+
1067
+ return result
1068
+
1069
+
1070
+ # ═══════════════════════════════════════════════════════════════════════════
1071
+ # 4. ГЛАВНАЯ ФУНКЦИЯ (V3.0 - УЛУЧШЕНО)
1072
+ # ═══════════════════════════════════════════════════════════════════════════
1073
+
1074
+ def validate_zoom_params(params):
1075
+ """
1076
+ V3.0: Добавлены новые параметры
1077
+ """
1078
+ z_mode = params.get('zoom_mode', 'outpaint_zoom')
1079
+ try:
1080
+ zoom_mode = ZoomMode(z_mode)
1081
+ except:
1082
+ zoom_mode = ZoomMode.OUTPAINT_ZOOM
1083
+
1084
+ b_mode = params.get('blend_mode', 'circular_reflect')
1085
+ try:
1086
+ blend_mode = BlendMode(b_mode)
1087
+ except:
1088
+ blend_mode = BlendMode.CIRCULAR_REFLECT
1089
+
1090
+ return {
1091
+ # Базовые параметры
1092
+ 'zoom_factor': float(params.get('zoom_factor', 0.0)),
1093
+ 'zoom_mode': zoom_mode,
1094
+ 'blend_mode': blend_mode,
1095
+ 'convergence_point': float(params.get('convergence_point', 0.5)),
1096
+ 'convergence_y': float(params.get('convergence_y', 0.5)),
1097
+ 'depth_power': float(params.get('depth_power', 1.0)),
1098
+ 'blend_falloff': str(params.get('blend_falloff', 'smoothstep')),
1099
+ 'blend_sharpness': float(params.get('blend_sharpness', 1.0)),
1100
+ 'blend_width': params.get('blend_width', None),
1101
+ 'pan_x': float(params.get('x_pan', 0.0)),
1102
+ 'pan_y': float(params.get('y_pan', 0.0)),
1103
+ 'fade_to_black': bool(params.get('zoom_fade_to_black', False)),
1104
+ 'fade_strength': float(params.get('zoom_fade_strength', 0.3)),
1105
+ 'fade_edge_strength': float(params.get('fade_edge_strength', 0.15)),
1106
+
1107
+ # V3.0: Новые параметры
1108
+ 'noise_strength': float(params.get('noise_strength', 1.0)),
1109
+ 'interp_mode': str(params.get('interp_mode', 'bilinear')),
1110
+ 'zoom_in_fade': bool(params.get('zoom_in_fade', True)),
1111
+ 'variance_correction': bool(params.get('variance_correction', True)),
1112
+ 'auto_clamp_pan': bool(params.get('auto_clamp_pan', True)),
1113
+ 'adaptive_noise_scale': bool(params.get('adaptive_noise_scale', True)),
1114
+ 'debug': bool(params.get('debug_mode', False)),
1115
+
1116
+ # V3.1: Новые параметры для spiral_zoom
1117
+ 'spiral_rotation': float(params.get('spiral_rotation', 0.5)),
1118
+ 'spiral_direction': float(params.get('spiral_direction', 1.0)),
1119
+
1120
+ # V3.1: Новые параметры для gradient_radial
1121
+ 'gradient_center_x': float(params.get('gradient_center_x', 0.5)),
1122
+ 'gradient_center_y': float(params.get('gradient_center_y', 0.5)),
1123
+ 'gradient_radius': float(params.get('gradient_radius', 1.0)),
1124
+
1125
+ # V3.1: Новые параметры для noise_blend
1126
+ 'noise_scale': float(params.get('noise_scale', 5.0)),
1127
+ 'noise_octaves': int(params.get('noise_octaves', 2)),
1128
+ }
1129
+
1130
+
1131
+ def apply_unified_zoom(input_tensor, pad_h, pad_w, zoom_factor=0.0,
1132
+ zoom_mode=ZoomMode.OUTPAINT_ZOOM,
1133
+ blend_mode=BlendMode.CIRCULAR_REFLECT,
1134
+ convergence_point=0.5, convergence_y=0.5,
1135
+ depth_power=1.0,
1136
+ blend_falloff='smoothstep', blend_sharpness=1.0,
1137
+ blend_width=None,
1138
+ pan_x=0.0, pan_y=0.0,
1139
+ fade_to_black=False, fade_strength=0.3,
1140
+ fade_edge_strength=0.15,
1141
+ noise_strength=1.0,
1142
+ interp_mode='bilinear',
1143
+ zoom_in_fade=True,
1144
+ variance_correction=True,
1145
+ auto_clamp_pan=True,
1146
+ adaptive_noise_scale=True,
1147
+ debug=False,
1148
+ extra_params=None):
1149
+ """
1150
+ V3.5 - UNIFIED ZOOM (Фикс Grid Warp + Hybrid Mode)
1151
+ """
1152
+ is_active = (abs(zoom_factor) > 0.001) or (abs(pan_x) > 0.001) or (abs(pan_y) > 0.001)
1153
+ force_original_size = True
1154
+ if extra_params and 'force_original_size' in extra_params:
1155
+ force_original_size = bool(extra_params['force_original_size'])
1156
+
1157
+ target_h, target_w = input_tensor.shape[2], input_tensor.shape[3]
1158
+
1159
+ if not is_active:
1160
+ return input_tensor
1161
+
1162
+ result = None
1163
+
1164
+ if zoom_mode == ZoomMode.OUTPAINT_ZOOM:
1165
+ result = apply_outpaint_zoom(
1166
+ input_tensor, zoom_factor, pad_h, pad_w,
1167
+ convergence=convergence_point, convergence_y=convergence_y,
1168
+ fade_strength=fade_strength, depth_power=depth_power,
1169
+ pan_x=pan_x, pan_y=pan_y,
1170
+ fade_to_black=fade_to_black, fade_edge_strength=fade_edge_strength,
1171
+ blend_mode=blend_mode.value if isinstance(blend_mode, BlendMode) else str(blend_mode),
1172
+ noise_strength=noise_strength, interp_mode=interp_mode,
1173
+ zoom_in_fade=zoom_in_fade, variance_correction=variance_correction,
1174
+ auto_clamp_pan=auto_clamp_pan, adaptive_noise_scale=adaptive_noise_scale,
1175
+ debug=debug, extra_params=extra_params
1176
+ )
1177
+
1178
+ elif zoom_mode == ZoomMode.GRID_WARP:
1179
+ x_warped = apply_grid_warp_zoom(
1180
+ input_tensor, zoom_factor, convergence_point, depth_power,
1181
+ pan_x, pan_y, convergence_y, interp_mode=interp_mode, debug=debug
1182
+ )
1183
+ # FIX: mode='zeros' не существует. Используем constant + value=0
1184
+ result = F.pad(x_warped, (pad_w, pad_w, pad_h, pad_h), mode='constant', value=0)
1185
+
1186
+ elif zoom_mode == ZoomMode.SPIRAL_ZOOM:
1187
+ s_rot = extra_params.get('spiral_rotation', 0.5) if extra_params else 0.5
1188
+ s_dir = extra_params.get('spiral_direction', 1.0) if extra_params else 1.0
1189
+ result = apply_spiral_zoom(
1190
+ input_tensor, zoom_factor, pad_h, pad_w,
1191
+ spiral_rotation=s_rot, spiral_direction=s_dir,
1192
+ interp_mode=interp_mode, debug=debug
1193
+ )
1194
+
1195
+ else: # Legacy
1196
+ x_shifted = apply_legacy_shift_zoom(
1197
+ input_tensor, zoom_factor, convergence_point, depth_power,
1198
+ pan_x, pan_y, auto_clamp_pan=auto_clamp_pan, debug=debug
1199
+ )
1200
+ result = F.pad(x_shifted, (pad_w, pad_w, pad_h, pad_h), mode='circular')
1201
+
1202
+ # ФИНАЛЬНЫЙ ФИКС РАЗМЕРА
1203
+ if force_original_size:
1204
+ if result.shape[2] != target_h or result.shape[3] != target_w:
1205
+ if debug: print(f"[Unified Zoom] Resizing back: {result.shape} -> {input_tensor.shape}")
1206
+ result = F.interpolate(result, size=(target_h, target_w), mode='bilinear', align_corners=True)
1207
+ else:
1208
+ if debug and (result.shape[2] != target_h or result.shape[3] != target_w):
1209
+ print(f"[Unified Zoom] Expanding Canvas: {input_tensor.shape} -> {result.shape}")
1210
+
1211
+ return result