File size: 10,133 Bytes
80abb6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import logging
import torch
import gradio as gr
from modules import scripts, shared
from modules.processing import StableDiffusionProcessing
from modules.infotext_utils import PasteField
from modules.ui_components import InputAccordion

# Настройка логгера
logger = logging.getLogger("ChunkWeight")
logger.setLevel(logging.INFO)

# ============================================================================
# ГЛОБАЛЬНОЕ СОСТОЯНИЕ
# ============================================================================
STATE = {
    'pos_weights': [],
    'neg_weights': [],
    'enabled': False,
    'original_method': None, 
}

# ============================================================================
# ЛОГИКА ОБРАБОТКИ ТЕНЗОРОВ (FIXED FOR SDXL)
# ============================================================================

def apply_weight_to_cond(cond, weight):
    """
    Применяет вес к кондишену (Тензору или Словарю).
    """
    if weight == 1.0:
        return cond

    if isinstance(cond, dict):
        # Логика для SDXL (Dict wrapper)
        new_cond = cond.copy()
        
        # Умножаем текстовые эмбеддинги (Cross-Attention)
        for key in ['crossattn', 'c_crossattn', 'open_clip_projected']:
            if key in new_cond:
                new_cond[key] = new_cond[key] * weight
                
        # Вектора стиля (pooled) 'vector' мы здесь НЕ умножаем на вес скалярно,
        # так как это может сломать нормализацию. Их вес будет учтен при усреднении (merge_conds).
        return new_cond
    
    elif isinstance(cond, torch.Tensor):
        # Логика для SD1.5 (Простой тензор)
        return cond * weight
    
    return cond


def merge_conds(cond_list, weights=None):
    """
    Склеивает список кондишенов чанков обратно в один промпт.
    ИСПРАВЛЕНО: Корректная обработка SDXL Pooled Vectors.
    """
    if not cond_list:
        return None
        
    first = cond_list[0]
    
    # --- Склейка для SDXL (Dictionary) ---
    if isinstance(first, dict):
        merged = {}
        for key in first.keys():
            tensors = [c[key] for c in cond_list if key in c]
            if not tensors:
                continue
                
            # Проверяем размерность, чтобы понять, как склеивать
            ndim = len(tensors[0].shape)
            
            if ndim == 3: 
                # [Batch, Tokens, Dim] -> CrossAttention. Склеиваем последовательно (в длину).
                merged[key] = torch.cat(tensors, dim=1)
                
            elif ndim == 2:
                # [Batch, Dim] -> Pooled Vector. Склеивать нельзя (ошибка mat1/mat2)!
                # Нужно усреднить вектора всех чанков.
                
                if weights and len(weights) == len(tensors):
                    # Взвешенное среднее: (V1*w1 + V2*w2) / (w1+w2)
                    # Это позволяет "весу чанка" влиять на глобальный стиль
                    stacked = torch.stack(tensors) # [N, B, D]
                    
                    # Приводим веса к форме [N, 1, 1] для умножения
                    w_tensor = torch.tensor(weights, device=stacked.device, dtype=stacked.dtype).view(-1, 1, 1)
                    
                    weighted_sum = (stacked * w_tensor).sum(dim=0) # [B, D]
                    total_weight = sum(weights) if sum(weights) != 0 else 1.0
                    
                    merged[key] = weighted_sum / total_weight
                else:
                    # Простое среднее, если весов нет
                    merged[key] = torch.stack(tensors).mean(dim=0)
            else:
                # Фолбэк для странных размерностей
                merged[key] = tensors[0]
                
        return merged

    # --- Склейка для SD1.5 (Tensor) ---
    elif isinstance(first, torch.Tensor):
        # Здесь всегда [Batch, Tokens, Dim], просто склеиваем
        return torch.cat(cond_list, dim=1)
    
    return first


def patched_get_learned_conditioning(prompts):
    """
    Подмененный метод получения эмбеддингов.
    """
    global STATE
    original_method = STATE['original_method']
    
    # Фолбэк безопасности
    if not STATE['enabled']:
        return original_method(prompts)

    if isinstance(prompts, str):
        prompts = [prompts]

    final_results = []

    for i, prompt in enumerate(prompts):
        # 1. Разбиваем по BREAK
        chunks = prompt.split("BREAK")
        chunk_tensors = []
        
        # 2. Определяем веса для текущего промпта
        # Эвристика: если кол-во весов совпадает с кол-вом чанков - используем их.
        # Это позволяет отличить Pos от Neg промпта, если у них разное кол-во чанков.
        current_weights = []
        
        if len(STATE['pos_weights']) >= len(chunks) and len(STATE['pos_weights']) > 0:
             current_weights = STATE['pos_weights'][:len(chunks)] # Берем ровно столько, сколько чанков
        elif len(STATE['neg_weights']) >= len(chunks) and len(STATE['neg_weights']) > 0:
             current_weights = STATE['neg_weights'][:len(chunks)]
        else:
             current_weights = [1.0] * len(chunks)

        # 3. Обработка чанков
        for idx, chunk_text in enumerate(chunks):
            # Получаем эмбеддинг чанка (Оригинальный метод)
            cond = original_method([chunk_text])
            
            # Получаем вес
            w = current_weights[idx]
            
            # Применяем вес (только к crossattn)
            if w != 1.0:
                cond = apply_weight_to_cond(cond, w)
            
            chunk_tensors.append(cond)

        # 4. Склеиваем (с учетом весов для Pooled векторов)
        merged = merge_conds(chunk_tensors, weights=current_weights)
        final_results.append(merged)

    # 5. Собираем итоговый батч (dim=0)
    if len(final_results) > 1:
        if isinstance(final_results[0], dict):
            # Batching для SDXL словарей
            batch_merged = {}
            for key in final_results[0].keys():
                batch_merged[key] = torch.cat([r[key] for r in final_results], dim=0)
            return batch_merged
        else:
            # Batching для SD1.5 тензоров
            return torch.cat(final_results, dim=0)
    else:
        return final_results[0]


# ============================================================================
# ИНТЕРФЕЙС
# ============================================================================

class ChunkWeightUltimateFixed(scripts.Script):
    def title(self):
        return "Chunk Weight (Ultimate SDXL Fix)"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        with InputAccordion(False, label="Chunk Weights") as enable:
            gr.Markdown("Версия с исправленной поддержкой SDXL. Разбивает по `BREAK`.")
            pos_weights = gr.Textbox(label="Positive Weights", placeholder="1.2, 0.8", lines=1)
            neg_weights = gr.Textbox(label="Negative Weights", placeholder="1.0, 0.5", lines=1)

        self.infotext_fields = [
            PasteField(pos_weights, "ChunkW+"),
            PasteField(neg_weights, "ChunkW-"),
        ]
        return [enable, pos_weights, neg_weights]

    def process(self, p: StableDiffusionProcessing, enable: bool, pos_str: str, neg_str: str):
        global STATE
        
        self.remove_patch() # Очистка старых патчей
        
        if not enable:
            return

        def parse(s):
            try: return [float(x.strip()) for x in s.split(',') if x.strip()]
            except: return []

        STATE['pos_weights'] = parse(pos_str)
        STATE['neg_weights'] = parse(neg_str)
        STATE['enabled'] = True
        
        if STATE['pos_weights']: p.extra_generation_params["ChunkW+"] = str(STATE['pos_weights'])
        if STATE['neg_weights']: p.extra_generation_params["ChunkW-"] = str(STATE['neg_weights'])

        # Патчинг get_learned_conditioning
        if shared.sd_model and hasattr(shared.sd_model, 'get_learned_conditioning'):
            logger.info("ChunkWeight: Patching model...")
            STATE['original_method'] = shared.sd_model.get_learned_conditioning
            shared.sd_model.get_learned_conditioning = patched_get_learned_conditioning
            
            # Сброс кэшей A1111 (Обязательно!)
            p.cached_c = [None, None]
            p.cached_uc = [None, None]
            p.cached_hr_c = [None, None]
            p.cached_hr_uc = [None, None]

    def postprocess(self, p, processed, *args):
        self.remove_patch()

    def remove_patch(self):
        global STATE
        if STATE['original_method'] and shared.sd_model:
            shared.sd_model.get_learned_conditioning = STATE['original_method']
            STATE['original_method'] = None
            STATE['enabled'] = False

def on_unload():
    global STATE
    if STATE['original_method'] and shared.sd_model:
        shared.sd_model.get_learned_conditioning = STATE['original_method']

scripts.script_callbacks.on_script_unloaded(on_unload)