dikdimon commited on
Commit
80abb6e
·
verified ·
1 Parent(s): eaaf420

Upload sd-webui-chunk-weights using SD-Hub

Browse files
sd-webui-chunk-weights/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
sd-webui-chunk-weights/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Haoming
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
sd-webui-chunk-weights/README.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SD Forge Chunk Weights
2
+ This is an Extension for [Forge Classic](https://github.com/Haoming02/sd-webui-forge-classic), which allows you to control the weighting for each chunk of prompts *(**i.e.** every 75 tokens)*.
3
+
4
+ > [!Tip]
5
+ > In the WebUI, you can use the keyword **`BREAK`** to manually separate prompts into different chunks to group similar concepts together
6
+
7
+ ## How to Use
8
+ In the `Weighting` text field, enter a list of **comma-separated floats**, corresponding to the weights of each chunk in order *(the default weight is `1.0`)*
9
+
10
+ ## Examples
11
+
12
+ <table>
13
+ <tr align="center">
14
+ <td>
15
+ <img src="./example/1.jpg" width=384><br>
16
+ a photo of a dog, a house<br>
17
+ <b>Extension:</b> <code>Disabled</code>
18
+ </td>
19
+ <td>
20
+ <img src="./example/2.jpg" width=384><br>
21
+ a photo of a dog, BREAK, a house<br>
22
+ <b>Extension:</b> <code>Disabled</code>
23
+ </td>
24
+ </tr>
25
+ <tr align="center">
26
+ <td>
27
+ <img src="./example/3.jpg" width=384><br>
28
+ a photo of a dog, BREAK, a house<br>
29
+ <b>Weights:</b> <code>1.5, 1.0</code>
30
+ </td>
31
+ <td>
32
+ <img src="./example/4.jpg" width=384><br>
33
+ a photo of a dog, BREAK, a house<br>
34
+ <b>Weights:</b> <code>1.0, 1.5</code>
35
+ </td>
36
+ </tr>
37
+ </table>
38
+
39
+ <hr>
40
+
41
+ - Idea by **[@jeanhadrien](https://github.com/jeanhadrien)** in [#89](https://github.com/Haoming02/sd-webui-forge-classic/issues/89), based on this [Extension](https://github.com/klimaleksus/stable-diffusion-webui-embedding-merge/)
sd-webui-chunk-weights/example/1.jpg ADDED

Git LFS Details

  • SHA256: ef650df9273ea36a09aa62de219f26ca0007f7facebc7ec8f002f3ae23eb6c2c
  • Pointer size: 131 Bytes
  • Size of remote file: 189 kB
sd-webui-chunk-weights/example/2.jpg ADDED

Git LFS Details

  • SHA256: 09e62f5b1764f78c85ec235ab5e863b63d0537e62469f92c3365c3c2195206a6
  • Pointer size: 131 Bytes
  • Size of remote file: 199 kB
sd-webui-chunk-weights/example/3.jpg ADDED

Git LFS Details

  • SHA256: b7a3d14ad47ad552a15412fc94667780f145eeae7d4511a9dad48c770f24e072
  • Pointer size: 131 Bytes
  • Size of remote file: 117 kB
sd-webui-chunk-weights/example/4.jpg ADDED

Git LFS Details

  • SHA256: aace8a20a06bcef678e5bca9e98c6f60d4edb92d58bc1e8156b26c676a0fc6fe
  • Pointer size: 131 Bytes
  • Size of remote file: 186 kB
sd-webui-chunk-weights/scripts/__pycache__/cw_logger.cpython-310.pyc ADDED
Binary file (1.18 kB). View file
 
sd-webui-chunk-weights/scripts/__pycache__/weighting.cpython-310.pyc ADDED
Binary file (5.89 kB). View file
 
sd-webui-chunk-weights/scripts/cw_logger.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import sys
3
+
4
+
class ColorCode:
    """ANSI escape sequences used to colorize log-level names."""

    RESET = "\033[0m"
    BLACK = "\033[0;90m"
    CYAN = "\033[0;36m"
    YELLOW = "\033[0;33m"
    RED = "\033[0;31m"

    # Maps a logging level name to the color its label is printed in.
    MAP = {
        "DEBUG": BLACK,
        "INFO": CYAN,
        "WARNING": YELLOW,
        "ERROR": RED,
    }
18
+
19
+
class ColoredFormatter(logging.Formatter):
    """Formatter that wraps the record's level name in ANSI color codes.

    Falls back to no color for level names without an entry in
    ``ColorCode.MAP`` (e.g. CRITICAL or custom levels) instead of raising
    KeyError from inside the logging machinery, which would suppress the
    log message entirely.
    """

    def format(self, record):
        levelname = record.levelname
        color = ColorCode.MAP.get(levelname, "")
        record.levelname = f"{color}{levelname}{ColorCode.RESET}"
        return super().format(record)
25
+
26
+
# Module-level logger for the extension; configured once, writes to stdout.
logger = logging.getLogger("ChunkWeight")
logger.setLevel(logging.INFO)
logger.propagate = False  # keep messages out of the root logger

# Guard against attaching duplicate handlers if the module is re-imported.
if not logger.handlers:
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(ColoredFormatter("[%(name)s] %(levelname)s - %(message)s"))
    logger.addHandler(stream_handler)
sd-webui-chunk-weights/scripts/weighting.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import torch
3
+ import gradio as gr
4
+ from modules import scripts, shared
5
+ from modules.processing import StableDiffusionProcessing
6
+ from modules.infotext_utils import PasteField
7
+ from modules.ui_components import InputAccordion
8
+
# Logger setup
logger = logging.getLogger("ChunkWeight")
logger.setLevel(logging.INFO)

# ============================================================================
# GLOBAL STATE
# ============================================================================
# Shared between the Script UI callbacks and the patched conditioning hook.
STATE = {
    'pos_weights': [],        # floats parsed from the "Positive Weights" textbox
    'neg_weights': [],        # floats parsed from the "Negative Weights" textbox
    'enabled': False,         # True only while a patched generation is active
    'original_method': None,  # saved sd_model.get_learned_conditioning
}
22
+
23
+ # ============================================================================
24
+ # ЛОГИКА ОБРАБОТКИ ТЕНЗОРОВ (FIXED FOR SDXL)
25
+ # ============================================================================
26
+
def apply_weight_to_cond(cond, weight):
    """Scale a conditioning object (plain tensor or SDXL-style dict) by *weight*.

    A weight of exactly 1.0 is a no-op and returns *cond* unchanged.
    For dict conditionings only the cross-attention text embeddings are
    scaled; pooled vectors are left alone — their weighting happens later,
    during the merge step (see merge_conds).
    """
    if weight == 1.0:
        return cond

    if isinstance(cond, dict):
        # SDXL path: the conditioning is a dict wrapper.
        scaled = cond.copy()
        # Multiply only the text embeddings (cross-attention).
        for key in ('crossattn', 'c_crossattn', 'open_clip_projected'):
            if key in scaled:
                scaled[key] = scaled[key] * weight
        # The pooled 'vector' entry is deliberately NOT scaled here: a scalar
        # multiply could break its normalization. Its weight is accounted for
        # during averaging in merge_conds() instead.
        return scaled

    if isinstance(cond, torch.Tensor):
        # SD1.5 path: plain tensor.
        return cond * weight

    return cond
52
+
53
+
def merge_conds(cond_list, weights=None):
    """Stitch per-chunk conditionings back into one prompt conditioning.

    Plain tensor chunks ([Batch, Tokens, Dim]) are concatenated along the
    token axis. For SDXL dict conditionings, 3-D entries (cross-attention)
    are concatenated while 2-D pooled vectors are averaged — weighted by
    *weights* when exactly one weight per chunk is supplied — because
    concatenating pooled vectors would cause a shape mismatch downstream.

    Returns None for an empty list.
    """
    if not cond_list:
        return None

    head = cond_list[0]

    # --- SDXL merge (dictionary conditioning) ---
    if isinstance(head, dict):
        merged = {}
        for key in head.keys():
            tensors = [c[key] for c in cond_list if key in c]
            if not tensors:
                continue

            # Rank decides how this entry is combined.
            ndim = tensors[0].dim()

            if ndim == 3:
                # [Batch, Tokens, Dim] -> cross-attention: concatenate along tokens.
                merged[key] = torch.cat(tensors, dim=1)
            elif ndim == 2:
                # [Batch, Dim] -> pooled vector: cannot be concatenated
                # (mat1/mat2 shape error) — average across chunks instead.
                if weights and len(weights) == len(tensors):
                    # Weighted mean (V1*w1 + V2*w2) / (w1 + w2) so the chunk
                    # weight also steers the global style vector.
                    stacked = torch.stack(tensors)  # [N, B, D]
                    # Reshape weights to [N, 1, 1] for broadcasting.
                    w = torch.tensor(weights, device=stacked.device, dtype=stacked.dtype).view(-1, 1, 1)
                    total = sum(weights) if sum(weights) != 0 else 1.0
                    merged[key] = (stacked * w).sum(dim=0) / total  # [B, D]
                else:
                    # Plain mean when no usable weight list was given.
                    merged[key] = torch.stack(tensors).mean(dim=0)
            else:
                # Fallback for unexpected ranks: keep the first chunk's value.
                merged[key] = tensors[0]

        return merged

    # --- SD1.5 merge (plain tensors): always [Batch, Tokens, Dim]. ---
    if isinstance(head, torch.Tensor):
        return torch.cat(cond_list, dim=1)

    return head
110
+
111
+
def patched_get_learned_conditioning(prompts):
    """Replacement for sd_model.get_learned_conditioning.

    Splits every prompt on the literal keyword "BREAK", encodes each chunk
    through the original method, scales the chunk embeddings by the
    configured weights, then merges the chunks back into one conditioning
    per prompt and re-assembles the batch.
    """
    global STATE
    original_method = STATE['original_method']

    # Safety fallback: behave exactly like the original when disabled.
    if not STATE['enabled']:
        return original_method(prompts)

    if isinstance(prompts, str):
        prompts = [prompts]

    final_results = []

    for prompt in prompts:
        # 1. Split the prompt into chunks on the BREAK keyword.
        chunks = prompt.split("BREAK")

        # 2. Pick the weight list for this prompt.
        #    Heuristic: a list is used only when it holds at least one
        #    weight per chunk — this is what distinguishes the positive
        #    from the negative prompt when their chunk counts differ.
        pos, neg = STATE['pos_weights'], STATE['neg_weights']
        if pos and len(pos) >= len(chunks):
            current_weights = pos[:len(chunks)]  # exactly one weight per chunk
        elif neg and len(neg) >= len(chunks):
            current_weights = neg[:len(chunks)]
        else:
            current_weights = [1.0] * len(chunks)

        # 3. Encode each chunk via the original method and apply its weight
        #    (cross-attention embeddings only).
        chunk_tensors = []
        for chunk_text, w in zip(chunks, current_weights):
            cond = original_method([chunk_text])
            if w != 1.0:
                cond = apply_weight_to_cond(cond, w)
            chunk_tensors.append(cond)

        # 4. Merge the chunks back together (the weights also steer the
        #    pooled vectors inside merge_conds).
        final_results.append(merge_conds(chunk_tensors, weights=current_weights))

    # 5. Re-assemble the final batch along dim=0.
    if len(final_results) == 1:
        return final_results[0]

    if isinstance(final_results[0], dict):
        # Batching for SDXL dictionaries: concatenate each entry separately.
        return {
            key: torch.cat([r[key] for r in final_results], dim=0)
            for key in final_results[0].keys()
        }

    # Batching for SD1.5 plain tensors.
    return torch.cat(final_results, dim=0)
176
+
177
+
178
+ # ============================================================================
179
+ # ИНТЕРФЕЙС
180
+ # ============================================================================
181
+
class ChunkWeightUltimateFixed(scripts.Script):
    """Extension UI: lets the user assign a weight to each BREAK-separated
    prompt chunk by patching sd_model.get_learned_conditioning."""

    def title(self):
        return "Chunk Weight (Ultimate SDXL Fix)"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        with InputAccordion(False, label="Chunk Weights") as enable:
            gr.Markdown("Версия с исправленной поддержкой SDXL. Разбивает по `BREAK`.")
            pos_weights = gr.Textbox(label="Positive Weights", placeholder="1.2, 0.8", lines=1)
            neg_weights = gr.Textbox(label="Negative Weights", placeholder="1.0, 0.5", lines=1)

        # Allow "Send to txt2img" / PNG-info paste to restore the weights.
        self.infotext_fields = [
            PasteField(pos_weights, "ChunkW+"),
            PasteField(neg_weights, "ChunkW-"),
        ]
        return [enable, pos_weights, neg_weights]

    def process(self, p: StableDiffusionProcessing, enable: bool, pos_str: str, neg_str: str):
        """Parse the weight strings, install the patch, and reset A1111's
        conditioning caches so the patched method actually runs."""
        global STATE

        self.remove_patch()  # Clear any patch left over from a previous run.

        if not enable:
            return

        def parse(s):
            # Comma-separated floats; malformed or missing input -> no weights.
            # Narrowed from a bare `except` (which also swallowed
            # KeyboardInterrupt/SystemExit): ValueError covers bad floats,
            # AttributeError covers a non-string (e.g. None) input.
            try:
                return [float(x.strip()) for x in s.split(',') if x.strip()]
            except (ValueError, AttributeError):
                return []

        STATE['pos_weights'] = parse(pos_str)
        STATE['neg_weights'] = parse(neg_str)
        STATE['enabled'] = True

        # Record the weights in the generation parameters / infotext.
        if STATE['pos_weights']:
            p.extra_generation_params["ChunkW+"] = str(STATE['pos_weights'])
        if STATE['neg_weights']:
            p.extra_generation_params["ChunkW-"] = str(STATE['neg_weights'])

        # Patch get_learned_conditioning with the chunk-weighting wrapper.
        if shared.sd_model and hasattr(shared.sd_model, 'get_learned_conditioning'):
            logger.info("ChunkWeight: Patching model...")
            STATE['original_method'] = shared.sd_model.get_learned_conditioning
            shared.sd_model.get_learned_conditioning = patched_get_learned_conditioning

            # Reset A1111's conditioning caches (mandatory — otherwise stale
            # unweighted conditionings from a previous run would be reused).
            p.cached_c = [None, None]
            p.cached_uc = [None, None]
            p.cached_hr_c = [None, None]
            p.cached_hr_uc = [None, None]

    def postprocess(self, p, processed, *args):
        # Always unpatch once generation finishes.
        self.remove_patch()

    def remove_patch(self):
        """Restore the original get_learned_conditioning (if patched) and
        reset the extension state."""
        global STATE
        if STATE['original_method'] and shared.sd_model:
            shared.sd_model.get_learned_conditioning = STATE['original_method']
        STATE['original_method'] = None
        STATE['enabled'] = False
241
+
def on_unload():
    """Script-unload callback: restore the original conditioning method so
    no stale patch survives an extension reload."""
    global STATE
    if STATE['original_method'] and shared.sd_model:
        shared.sd_model.get_learned_conditioning = STATE['original_method']


scripts.script_callbacks.on_script_unloaded(on_unload)