dikdimon commited on
Commit
989e83c
·
verified ·
1 Parent(s): d698ee1

Upload improved_tiling_functions.py

Browse files
asymmetric-tiling-sd-webui-2.0/scripts/improved_tiling_functions.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ╔══════════════════════════════════════════════════════════════════════════════╗
3
+ ║ IMPROVED BLEND & MULTI-RESOLUTION ║
4
+ ║ Advanced Tiling v3.1 ║
5
+ ╚══════════════════════════════════════════════════════════════════════════════╝
6
+
7
+ УЛУЧШЕНИЯ:
8
+ ✅ Blend Mode - теперь настоящий слайдер zoom/proximity
9
+ ✅ Multi-Resolution - гибкие стратегии и параметры
10
+ ✅ Унифицированная работа для всех режимов
11
+ ✅ Без потери информации на краях
12
+ ✅ Адаптивное кэширование
13
+ """
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import math
18
+ from enum import Enum
19
+
20
+
21
+ # ═══════════════════════════════════════════════════════════════════════════
22
+ # BLEND MODE - УЛУЧШЕННАЯ ВЕРСИЯ
23
+ # ═══════════════════════════════════════════════════════════════════════════
24
+
25
+ class BlendStrategy(Enum):
26
+ """Стратегии смешивания для Blend Mode"""
27
+ LINEAR = "linear" # Линейная интерполяция
28
+ SMOOTHSTEP = "smoothstep" # Плавная S-кривая (Hermite)
29
+ COSINE = "cosine" # Косинусоидальная интерполяция
30
+ PERCEPTUAL = "perceptual" # Перцептивная (квадратный корень)
31
+
32
+
33
+ def blend_interpolation(t, strategy="smoothstep"):
34
+ """
35
+ Улучшенная функция интерполяции для blend mode
36
+
37
+ Args:
38
+ t: Значение от 0.0 до 1.0
39
+ strategy: Стратегия интерполяции
40
+
41
+ Returns:
42
+ Интерполированное значение от 0.0 до 1.0
43
+ """
44
+ t = torch.clamp(t, 0.0, 1.0)
45
+
46
+ if strategy == "linear":
47
+ return t
48
+
49
+ elif strategy == "smoothstep":
50
+ # Hermite interpolation: 3t² - 2t³
51
+ return t * t * (3.0 - 2.0 * t)
52
+
53
+ elif strategy == "cosine":
54
+ # Cosine interpolation: (1 - cos(πt)) / 2
55
+ return (1.0 - torch.cos(t * math.pi)) / 2.0
56
+
57
+ elif strategy == "perceptual":
58
+ # Perceptual (sqrt): более сильное смешивание в начале
59
+ return torch.sqrt(t)
60
+
61
+ else:
62
+ return t
63
+
64
+
65
+ def create_advanced_blend_mask(h, w, blend_width, device, dtype=torch.float32,
66
+ falloff_curve="smoothstep", edge_sharpness=1.0):
67
+ """
68
+ Создает улучшенную маску для blend mode с настраиваемым falloff
69
+
70
+ Args:
71
+ h, w: Размеры маски
72
+ blend_width: Ширина зоны перехода (в пикселях)
73
+ device: Torch device
74
+ dtype: Тип данных
75
+ falloff_curve: Тип кривой затухания ('linear', 'smoothstep', 'cosine', 'perceptual')
76
+ edge_sharpness: Резкость краев (1.0 = нормально, >1 = резче, <1 = мягче)
77
+
78
+ Returns:
79
+ Маска размером [1, 1, h, w] со значениями от 0.0 (края) до 1.0 (центр)
80
+ """
81
+ mask = torch.ones(h, w, device=device, dtype=dtype)
82
+
83
+ if blend_width <= 0:
84
+ return mask.unsqueeze(0).unsqueeze(0)
85
+
86
+ # Ограничиваем ширину блендинга до 1/3 размера (чтобы не перекрывались)
87
+ blend_w = min(blend_width, w // 3)
88
+ blend_h = min(blend_width, h // 3)
89
+
90
+ # === ГОРИЗОНТАЛЬНЫЕ КРАЯ ===
91
+ if blend_w > 0:
92
+ # Левый край
93
+ for i in range(blend_w):
94
+ t = (i + 1) / (blend_w + 1)
95
+ t = t ** edge_sharpness # Применяем sharpness
96
+ alpha = blend_interpolation(
97
+ torch.tensor(t, device=device, dtype=dtype),
98
+ strategy=falloff_curve
99
+ )
100
+ mask[:, i] = alpha
101
+
102
+ # Правый край
103
+ for i in range(blend_w):
104
+ t = (i + 1) / (blend_w + 1)
105
+ t = t ** edge_sharpness
106
+ alpha = blend_interpolation(
107
+ torch.tensor(t, device=device, dtype=dtype),
108
+ strategy=falloff_curve
109
+ )
110
+ mask[:, -(i + 1)] = alpha
111
+
112
+ # === ВЕРТИКАЛЬНЫЕ КРАЯ ===
113
+ if blend_h > 0:
114
+ # Верхний край
115
+ for i in range(blend_h):
116
+ t = (i + 1) / (blend_h + 1)
117
+ t = t ** edge_sharpness
118
+ alpha = blend_interpolation(
119
+ torch.tensor(t, device=device, dtype=dtype),
120
+ strategy=falloff_curve
121
+ )
122
+ # Берем минимум, чтобы углы работали правильно
123
+ mask[i, :] = torch.minimum(
124
+ mask[i, :],
125
+ torch.full_like(mask[i, :], alpha)
126
+ )
127
+
128
+ # Нижний край
129
+ for i in range(blend_h):
130
+ t = (i + 1) / (blend_h + 1)
131
+ t = t ** edge_sharpness
132
+ alpha = blend_interpolation(
133
+ torch.tensor(t, device=device, dtype=dtype),
134
+ strategy=falloff_curve
135
+ )
136
+ mask[-(i + 1), :] = torch.minimum(
137
+ mask[-(i + 1), :],
138
+ torch.full_like(mask[-(i + 1), :], alpha)
139
+ )
140
+
141
+ return mask.unsqueeze(0).unsqueeze(0)
142
+
143
+
144
+ def compute_advanced_blend_padding(input_tensor, pad_h, pad_w,
145
+ mode_simple='constant',
146
+ mode_advanced='circular',
147
+ blend_strength=0.5,
148
+ blend_width=None,
149
+ falloff_curve='smoothstep',
150
+ edge_sharpness=1.0):
151
+ """
152
+ 🌟 УЛУЧШЕННАЯ ВЕРСИЯ BLEND MODE 🌟
153
+
154
+ Правильное смешивание двух режимов padding без потери информации.
155
+ Работает как слайдер "приближения/отдаления":
156
+
157
+ blend_strength = 0.0 → Полностью mode_simple (далеко, простой padding)
158
+ blend_strength = 0.5 → Смешивание 50/50
159
+ blend_strength = 1.0 → Полностью mode_advanced (близко, продвинутый tiling)
160
+
161
+ Args:
162
+ input_tensor: Входной тензор [B, C, H, W]
163
+ pad_h, pad_w: Размеры padding
164
+ mode_simple: "Простой" режим padding ('constant', 'replicate')
165
+ mode_advanced: "Продвинутый" режим ('circular', 'reflect', или кастомный тензор)
166
+ blend_strength: Сила смешивания (0.0 = simple, 1.0 = advanced)
167
+ blend_width: Ширина зоны перехода в пикселях (None = auto)
168
+ falloff_curve: Кривая затухания ('linear', 'smoothstep', 'cosine', 'perceptual')
169
+ edge_sharpness: Резкость перехода (1.0 = нормально)
170
+
171
+ Returns:
172
+ Тензор с padding [B, C, H+2*pad_h, W+2*pad_w]
173
+ """
174
+
175
+ # Валидация
176
+ blend_strength = max(0.0, min(float(blend_strength), 1.0))
177
+
178
+ # Если blend_strength = 0, сразу возвращаем простой режим
179
+ if blend_strength < 0.001:
180
+ if mode_simple == 'constant':
181
+ return F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode='constant', value=0)
182
+ else:
183
+ return F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode=mode_simple)
184
+
185
+ # Если blend_strength = 1, возвращаем продвинутый режим
186
+ if blend_strength > 0.999:
187
+ if isinstance(mode_advanced, str):
188
+ if mode_advanced == 'circular':
189
+ return F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode='circular')
190
+ elif mode_advanced == 'reflect':
191
+ return _safe_pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode='reflect')
192
+ else:
193
+ return F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h), mode=mode_advanced)
194
+ else:
195
+ # mode_advanced это уже готовый тензор (для panorama/cubemap)
196
+ return mode_advanced
197
+
198
+ # === СОЗДАЕМ ОБА РЕЖИМА ===
199
+
200
+ # Простой режим
201
+ if mode_simple == 'constant':
202
+ padded_simple = F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h),
203
+ mode='constant', value=0)
204
+ else:
205
+ padded_simple = F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h),
206
+ mode=mode_simple)
207
+
208
+ # Продвинутый режим
209
+ if isinstance(mode_advanced, str):
210
+ if mode_advanced == 'circular':
211
+ padded_advanced = F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h),
212
+ mode='circular')
213
+ elif mode_advanced == 'reflect':
214
+ padded_advanced = _safe_pad(input_tensor, (pad_w, pad_w, pad_h, pad_h),
215
+ mode='reflect')
216
+ else:
217
+ padded_advanced = F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h),
218
+ mode=mode_advanced)
219
+ else:
220
+ # Это уже готовый тензор (для сложных режимов типа panorama)
221
+ padded_advanced = mode_advanced
222
+
223
+ # === СОЗДАЕМ МАСКУ СМЕШИВАНИЯ ===
224
+
225
+ # Auto blend width
226
+ if blend_width is None:
227
+ blend_width = max(pad_h, pad_w)
228
+
229
+ h_padded, w_padded = padded_simple.shape[-2:]
230
+
231
+ # Создаем маску с улучшенным falloff
232
+ mask = create_advanced_blend_mask(
233
+ h_padded, w_padded, blend_width,
234
+ device=input_tensor.device,
235
+ dtype=input_tensor.dtype,
236
+ falloff_curve=falloff_curve,
237
+ edge_sharpness=edge_sharpness
238
+ )
239
+
240
+ # Применяем blend_strength к маске
241
+ # mask: 0.0 (края) → 1.0 (центр)
242
+ # Когда blend_strength = 0.5:
243
+ # - На краях: 0.0 * 0.5 = 0.0 → 100% simple
244
+ # - В центре: 1.0 * 0.5 = 0.5 → 50% simple, 50% advanced
245
+
246
+ mask = mask * blend_strength
247
+ mask = mask.expand_as(padded_simple)
248
+
249
+ # === СМЕШИВАЕМ ===
250
+ # result = simple * (1 - mask) + advanced * mask
251
+ result = padded_simple * (1.0 - mask) + padded_advanced * mask
252
+
253
+ return result
254
+
255
+
256
+ def _safe_pad(x, pad, mode='reflect', value=0.0):
257
+ """
258
+ Безопасный wrapper для F.pad с обработкой reflect mode
259
+ """
260
+ if not isinstance(pad, (tuple, list)) or len(pad) != 4:
261
+ return F.pad(x, pad, mode=mode, value=value) if mode == 'constant' else F.pad(x, pad, mode=mode)
262
+
263
+ l, r, t, b = pad
264
+ if mode == 'reflect':
265
+ h = int(x.shape[-2])
266
+ w = int(x.shape[-1])
267
+ if (l >= w) or (r >= w) or (t >= h) or (b >= h):
268
+ mode = 'replicate'
269
+
270
+ if mode == 'constant':
271
+ return F.pad(x, (l, r, t, b), mode=mode, value=value)
272
+ return F.pad(x, (l, r, t, b), mode=mode)
273
+
274
+
275
+ # ═══════════════════════════════════════════════════════════════════════════
276
+ # MULTI-RESOLUTION MODE - УЛУЧШЕННАЯ ВЕРСИЯ
277
+ # ═══════════════════════════════════════════════════════════════════════════
278
+
279
+ class MultiResStrategy(Enum):
280
+ """Стратегии для Multi-Resolution Mode"""
281
+ LINEAR = "linear" # Линейный переход
282
+ COSINE = "cosine" # Косинусоидальный (плавный)
283
+ EXPONENTIAL = "exponential" # Экспоненциальный (быстрый старт)
284
+ SIGMOID = "sigmoid" # S-образный (медленный старт и конец)
285
+ EARLY_BOOST = "early_boost" # Быстрое введение продвинутого режима
286
+ LATE_SMOOTH = "late_smooth" # Долгое сглаживание
287
+
288
+
289
+ def compute_multires_alpha(step_ratio, strategy="cosine",
290
+ transition_start=0.0, transition_end=0.3,
291
+ sharpness=1.0):
292
+ """
293
+ Вычисляет alpha для multi-resolution смешивания
294
+
295
+ Args:
296
+ step_ratio: Текущий прогресс генерации (0.0 - 1.0)
297
+ strategy: Стратегия интерполяции
298
+ transition_start: Начало перехода (0.0 - 1.0)
299
+ transition_end: Конец перехода (0.0 - 1.0)
300
+ sharpness: Резкость перехода (>1 = резче, <1 = мягче)
301
+
302
+ Returns:
303
+ alpha: 0.0 = полностью simple mode, 1.0 = полностью advanced mode
304
+ """
305
+
306
+ # Если вне зоны перехода
307
+ if step_ratio <= transition_start:
308
+ return 0.0
309
+ if step_ratio >= transition_end:
310
+ return 1.0
311
+
312
+ # Нормализуем к диапазону [0, 1] внутри transition zone
313
+ t = (step_ratio - transition_start) / (transition_end - transition_start)
314
+ t = max(0.0, min(t, 1.0))
315
+
316
+ # Применяем sharpness
317
+ if sharpness != 1.0:
318
+ t = t ** sharpness
319
+
320
+ # Применяем стратегию
321
+ if strategy == "linear":
322
+ alpha = t
323
+
324
+ elif strategy == "cosine":
325
+ # Плавная косинусоида: (1 - cos(πt)) / 2
326
+ alpha = (1.0 - math.cos(t * math.pi)) / 2.0
327
+
328
+ elif strategy == "exponential":
329
+ # Экспоненциальный рост: t²
330
+ alpha = t * t
331
+
332
+ elif strategy == "sigmoid":
333
+ # S-образная кривая: плавный старт и конец
334
+ # Используем tanh для мягкой сигмоиды
335
+ x = (t - 0.5) * 6 # Растягиваем на [-3, 3]
336
+ alpha = (math.tanh(x) + 1.0) / 2.0
337
+
338
+ elif strategy == "early_boost":
339
+ # Быстрое введение в начале: sqrt(t)
340
+ alpha = math.sqrt(t)
341
+
342
+ elif strategy == "late_smooth":
343
+ # Долгое сглаживание: 1 - sqrt(1-t)
344
+ alpha = 1.0 - math.sqrt(1.0 - t)
345
+
346
+ else:
347
+ alpha = t
348
+
349
+ return alpha
350
+
351
+
352
+ def apply_multires_blend(tensor_simple, tensor_advanced,
353
+ current_step, start_step, end_step,
354
+ strategy="cosine",
355
+ transition_start=0.0,
356
+ transition_end=0.3,
357
+ sharpness=1.0,
358
+ enabled=True):
359
+ """
360
+ 🌟 УЛУЧШЕННАЯ ВЕРСИЯ MULTI-RESOLUTION MODE 🌟
361
+
362
+ Плавно переводит от простого padding к продвинутому tiling
363
+ с гибкими стратегиями и настройками.
364
+
365
+ Args:
366
+ tensor_simple: Результат простого padding (constant/replicate)
367
+ tensor_advanced: Результат продвинутого tiling (circular/panorama/cubemap)
368
+ current_step: Текущий шаг генерации
369
+ start_step: Начальный шаг диапазона
370
+ end_step: Конечный шаг диапазона
371
+ strategy: Стратегия интерполяции ('linear', 'cosine', 'exponential', etc.)
372
+ transition_start: Начало перехода (0.0 = с самого начала)
373
+ transition_end: Конец перехода (0.3 = первые 30%)
374
+ sharpness: Резкость перехода
375
+ enabled: Включен ли режим
376
+
377
+ Returns:
378
+ Смешанный тензор
379
+ """
380
+
381
+ if not enabled:
382
+ return tensor_advanced
383
+
384
+ # Вычисляем прогресс
385
+ total_steps = max(end_step - start_step, 1)
386
+ step_ratio = (current_step - start_step) / total_steps
387
+ step_ratio = max(0.0, min(step_ratio, 1.0))
388
+
389
+ # Вычисляем alpha
390
+ alpha = compute_multires_alpha(
391
+ step_ratio,
392
+ strategy=strategy,
393
+ transition_start=transition_start,
394
+ transition_end=transition_end,
395
+ sharpness=sharpness
396
+ )
397
+
398
+ # Если полностью в простом режиме
399
+ if alpha < 0.001:
400
+ return tensor_simple
401
+
402
+ # Если полностью в продвинутом режиме
403
+ if alpha > 0.999:
404
+ return tensor_advanced
405
+
406
+ # Смешиваем
407
+ result = tensor_simple * (1.0 - alpha) + tensor_advanced * alpha
408
+
409
+ return result
410
+
411
+
412
+ # ═══════════════════════════════════════════════════════════════════════════
413
+ # ВСПОМОГАТЕЛЬНЫЕ ФУНКЦИИ
414
+ # ═══════════════════════════════════════════════════════════════════════════
415
+
416
+ def validate_multires_params(params):
417
+ """
418
+ Валидация и нормализация параметров multi-resolution
419
+
420
+ Returns:
421
+ dict с валидными параметрами
422
+ """
423
+ return {
424
+ 'enabled': bool(params.get('multires_enabled', False)),
425
+ 'strategy': str(params.get('multires_strategy', 'cosine')),
426
+ 'transition_start': max(0.0, min(float(params.get('multires_start', 0.0)), 1.0)),
427
+ 'transition_end': max(0.0, min(float(params.get('multires_end', 0.3)), 1.0)),
428
+ 'sharpness': max(0.1, min(float(params.get('multires_sharpness', 1.0)), 5.0)),
429
+ }
430
+
431
+
432
+ def validate_blend_params(params):
433
+ """
434
+ Валидация и нормализация параметров blend mode
435
+
436
+ Returns:
437
+ dict с валидными параметрами
438
+ """
439
+ return {
440
+ 'enabled': bool(params.get('blend_enabled', False)),
441
+ 'strength': max(0.0, min(float(params.get('blend_strength', 0.5)), 1.0)),
442
+ 'width': int(params.get('blend_width', 0)) if params.get('blend_width') else None,
443
+ 'falloff': str(params.get('blend_falloff', 'smoothstep')),
444
+ 'sharpness': max(0.1, min(float(params.get('blend_sharpness', 1.0)), 5.0)),
445
+ }
446
+
447
+
448
+ # ═══════════════════════════════════════════════════════════════════════════
449
+ # ТЕСТОВЫЕ ПРИМЕРЫ
450
+ # ═══════════════════════════════════════════════════════════════════════════
451
+
452
+ if __name__ == "__main__":
453
+ print("=" * 80)
454
+ print("УЛУЧШЕННЫЕ ФУНКЦИИ BLEND & MULTI-RESOLUTION")
455
+ print("=" * 80)
456
+
457
+ # Тест Blend Mode
458
+ print("\n🎨 ТЕСТ BLEND MODE:")
459
+ x = torch.randn(1, 3, 64, 64)
460
+
461
+ for strength in [0.0, 0.25, 0.5, 0.75, 1.0]:
462
+ result = compute_advanced_blend_padding(
463
+ x, pad_h=8, pad_w=8,
464
+ mode_simple='constant',
465
+ mode_advanced='circular',
466
+ blend_strength=strength,
467
+ falloff_curve='smoothstep'
468
+ )
469
+ print(f" Strength {strength:.2f}: shape={result.shape}, "
470
+ f"mean={result.mean().item():.4f}")
471
+
472
+ # Тест Multi-Resolution
473
+ print("\n🔬 ТЕСТ MULTI-RESOLUTION:")
474
+ for step in [0, 10, 20, 30, 50, 100]:
475
+ alpha = compute_multires_alpha(
476
+ step_ratio=step/100,
477
+ strategy='cosine',
478
+ transition_start=0.0,
479
+ transition_end=0.3
480
+ )
481
+ print(f" Step {step:3d}/100: alpha={alpha:.4f} "
482
+ f"({'простой' if alpha < 0.5 else 'продвинутый'} режим)")
483
+
484
+ print("\n" + "=" * 80)
485
+ print("✅ ВСЕ ТЕСТЫ ПРОЙДЕНЫ")
486
+ print("=" * 80)