File size: 10,133 Bytes
80abb6e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 | import logging
import torch
import gradio as gr
from modules import scripts, shared
from modules.processing import StableDiffusionProcessing
from modules.infotext_utils import PasteField
from modules.ui_components import InputAccordion
# Настройка логгера
logger = logging.getLogger("ChunkWeight")
logger.setLevel(logging.INFO)
# ============================================================================
# ГЛОБАЛЬНОЕ СОСТОЯНИЕ
# ============================================================================
STATE = {
'pos_weights': [],
'neg_weights': [],
'enabled': False,
'original_method': None,
}
# ============================================================================
# ЛОГИКА ОБРАБОТКИ ТЕНЗОРОВ (FIXED FOR SDXL)
# ============================================================================
def apply_weight_to_cond(cond, weight):
"""
Применяет вес к кондишену (Тензору или Словарю).
"""
if weight == 1.0:
return cond
if isinstance(cond, dict):
# Логика для SDXL (Dict wrapper)
new_cond = cond.copy()
# Умножаем текстовые эмбеддинги (Cross-Attention)
for key in ['crossattn', 'c_crossattn', 'open_clip_projected']:
if key in new_cond:
new_cond[key] = new_cond[key] * weight
# Вектора стиля (pooled) 'vector' мы здесь НЕ умножаем на вес скалярно,
# так как это может сломать нормализацию. Их вес будет учтен при усреднении (merge_conds).
return new_cond
elif isinstance(cond, torch.Tensor):
# Логика для SD1.5 (Простой тензор)
return cond * weight
return cond
def merge_conds(cond_list, weights=None):
"""
Склеивает список кондишенов чанков обратно в один промпт.
ИСПРАВЛЕНО: Корректная обработка SDXL Pooled Vectors.
"""
if not cond_list:
return None
first = cond_list[0]
# --- Склейка для SDXL (Dictionary) ---
if isinstance(first, dict):
merged = {}
for key in first.keys():
tensors = [c[key] for c in cond_list if key in c]
if not tensors:
continue
# Проверяем размерность, чтобы понять, как склеивать
ndim = len(tensors[0].shape)
if ndim == 3:
# [Batch, Tokens, Dim] -> CrossAttention. Склеиваем последовательно (в длину).
merged[key] = torch.cat(tensors, dim=1)
elif ndim == 2:
# [Batch, Dim] -> Pooled Vector. Склеивать нельзя (ошибка mat1/mat2)!
# Нужно усреднить вектора всех чанков.
if weights and len(weights) == len(tensors):
# Взвешенное среднее: (V1*w1 + V2*w2) / (w1+w2)
# Это позволяет "весу чанка" влиять на глобальный стиль
stacked = torch.stack(tensors) # [N, B, D]
# Приводим веса к форме [N, 1, 1] для умножения
w_tensor = torch.tensor(weights, device=stacked.device, dtype=stacked.dtype).view(-1, 1, 1)
weighted_sum = (stacked * w_tensor).sum(dim=0) # [B, D]
total_weight = sum(weights) if sum(weights) != 0 else 1.0
merged[key] = weighted_sum / total_weight
else:
# Простое среднее, если весов нет
merged[key] = torch.stack(tensors).mean(dim=0)
else:
# Фолбэк для странных размерностей
merged[key] = tensors[0]
return merged
# --- Склейка для SD1.5 (Tensor) ---
elif isinstance(first, torch.Tensor):
# Здесь всегда [Batch, Tokens, Dim], просто склеиваем
return torch.cat(cond_list, dim=1)
return first
def patched_get_learned_conditioning(prompts):
"""
Подмененный метод получения эмбеддингов.
"""
global STATE
original_method = STATE['original_method']
# Фолбэк безопасности
if not STATE['enabled']:
return original_method(prompts)
if isinstance(prompts, str):
prompts = [prompts]
final_results = []
for i, prompt in enumerate(prompts):
# 1. Разбиваем по BREAK
chunks = prompt.split("BREAK")
chunk_tensors = []
# 2. Определяем веса для текущего промпта
# Эвристика: если кол-во весов совпадает с кол-вом чанков - используем их.
# Это позволяет отличить Pos от Neg промпта, если у них разное кол-во чанков.
current_weights = []
if len(STATE['pos_weights']) >= len(chunks) and len(STATE['pos_weights']) > 0:
current_weights = STATE['pos_weights'][:len(chunks)] # Берем ровно столько, сколько чанков
elif len(STATE['neg_weights']) >= len(chunks) and len(STATE['neg_weights']) > 0:
current_weights = STATE['neg_weights'][:len(chunks)]
else:
current_weights = [1.0] * len(chunks)
# 3. Обработка чанков
for idx, chunk_text in enumerate(chunks):
# Получаем эмбеддинг чанка (Оригинальный метод)
cond = original_method([chunk_text])
# Получаем вес
w = current_weights[idx]
# Применяем вес (только к crossattn)
if w != 1.0:
cond = apply_weight_to_cond(cond, w)
chunk_tensors.append(cond)
# 4. Склеиваем (с учетом весов для Pooled векторов)
merged = merge_conds(chunk_tensors, weights=current_weights)
final_results.append(merged)
# 5. Собираем итоговый батч (dim=0)
if len(final_results) > 1:
if isinstance(final_results[0], dict):
# Batching для SDXL словарей
batch_merged = {}
for key in final_results[0].keys():
batch_merged[key] = torch.cat([r[key] for r in final_results], dim=0)
return batch_merged
else:
# Batching для SD1.5 тензоров
return torch.cat(final_results, dim=0)
else:
return final_results[0]
# ============================================================================
# ИНТЕРФЕЙС
# ============================================================================
class ChunkWeightUltimateFixed(scripts.Script):
def title(self):
return "Chunk Weight (Ultimate SDXL Fix)"
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
with InputAccordion(False, label="Chunk Weights") as enable:
gr.Markdown("Версия с исправленной поддержкой SDXL. Разбивает по `BREAK`.")
pos_weights = gr.Textbox(label="Positive Weights", placeholder="1.2, 0.8", lines=1)
neg_weights = gr.Textbox(label="Negative Weights", placeholder="1.0, 0.5", lines=1)
self.infotext_fields = [
PasteField(pos_weights, "ChunkW+"),
PasteField(neg_weights, "ChunkW-"),
]
return [enable, pos_weights, neg_weights]
def process(self, p: StableDiffusionProcessing, enable: bool, pos_str: str, neg_str: str):
global STATE
self.remove_patch() # Очистка старых патчей
if not enable:
return
def parse(s):
try: return [float(x.strip()) for x in s.split(',') if x.strip()]
except: return []
STATE['pos_weights'] = parse(pos_str)
STATE['neg_weights'] = parse(neg_str)
STATE['enabled'] = True
if STATE['pos_weights']: p.extra_generation_params["ChunkW+"] = str(STATE['pos_weights'])
if STATE['neg_weights']: p.extra_generation_params["ChunkW-"] = str(STATE['neg_weights'])
# Патчинг get_learned_conditioning
if shared.sd_model and hasattr(shared.sd_model, 'get_learned_conditioning'):
logger.info("ChunkWeight: Patching model...")
STATE['original_method'] = shared.sd_model.get_learned_conditioning
shared.sd_model.get_learned_conditioning = patched_get_learned_conditioning
# Сброс кэшей A1111 (Обязательно!)
p.cached_c = [None, None]
p.cached_uc = [None, None]
p.cached_hr_c = [None, None]
p.cached_hr_uc = [None, None]
def postprocess(self, p, processed, *args):
self.remove_patch()
def remove_patch(self):
global STATE
if STATE['original_method'] and shared.sd_model:
shared.sd_model.get_learned_conditioning = STATE['original_method']
STATE['original_method'] = None
STATE['enabled'] = False
def on_unload():
global STATE
if STATE['original_method'] and shared.sd_model:
shared.sd_model.get_learned_conditioning = STATE['original_method']
scripts.script_callbacks.on_script_unloaded(on_unload) |