Update sd-webui-chunk-weights/scripts/chunk_weighting.py

#1
by Dikz - opened
sd-webui-chunk-weights/scripts/chunk_weighting.py CHANGED
@@ -11,7 +11,7 @@ from modules.script_callbacks import on_app_started, on_script_unloaded
11
  from modules.ui_components import InputAccordion
12
 
13
  # ==============================================================================
14
- # ЧАСТЬ 1: ЛОГГЕР (Встроено, чтобы избежать ошибок импорта)
15
  # ==============================================================================
16
  class ColorCode:
17
  RESET = "\033[0m"
@@ -47,15 +47,12 @@ if not logger.handlers:
47
  # ЧАСТЬ 2: СКРИПТ (Логика весов)
48
  # ==============================================================================
49
 
50
- # Поиск целевого класса для патчинга.
51
- # Мы ищем базовый класс, от которого наследуются и SD1.5, и SDXL эмбеддеры.
52
  target_classes = []
53
 
54
  try:
55
  from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWordsBase
56
  target_classes.append(FrozenCLIPEmbedderWithCustomWordsBase)
57
  except ImportError:
58
- # Фолбэк для очень старых версий или если архитектура изменится
59
  logger.warning("Base class FrozenCLIPEmbedderWithCustomWordsBase not found. Trying individual classes.")
60
  try:
61
  from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWords
@@ -102,7 +99,6 @@ class ChunkWeight(scripts.Script):
102
  if not enable:
103
  return
104
 
105
- # Парсинг весов
106
  for v in weights.split(","):
107
  v = v.strip()
108
  if not v: continue
@@ -114,18 +110,14 @@ class ChunkWeight(scripts.Script):
114
 
115
  p.extra_generation_params["Chunk Weights"] = ", ".join(str(v) for v in WEIGHTS)
116
 
117
- # Сброс кэшей A1111. Это критически важно, иначе изменение весов
118
- # не применится к повторным генерациям с тем же промптом.
119
  p.cached_c = [None, None]
120
  p.cached_uc = [None, None]
121
 
122
- # Поддержка Hires Fix (Highres. fix кэширует свои кондишены отдельно)
123
  if hasattr(p, 'cached_hr_c'):
124
  p.cached_hr_c = [None, None]
125
  p.cached_hr_uc = [None, None]
126
 
127
  def postprocess(self, *args):
128
- # Очистка глобальных кэшей класса после генерации для надежности
129
  StableDiffusionProcessing.cached_c = [None, None]
130
  StableDiffusionProcessing.cached_uc = [None, None]
131
  StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
@@ -133,12 +125,10 @@ class ChunkWeight(scripts.Script):
133
 
134
 
135
  # ==============================================================================
136
- # ЧАСТЬ 3: ПАТЧИНГ (Внедрение в ядро)
137
  # ==============================================================================
138
 
139
  def patch_embedder(cls):
140
- """Патчит класс Embedder, добавляя умножение весов."""
141
-
142
  if hasattr(cls, '_chunk_weight_patched'):
143
  return
144
 
@@ -149,16 +139,11 @@ def patch_embedder(cls):
149
  def patched_process_texts(self, texts: List[str]):
150
  global IS_NEGATIVE_PROMPT, INDEX
151
 
152
- # A1111 передает флаг is_negative_prompt как атрибут списка texts.
153
- # Это работает благодаря классу SdConditioning в modules/prompt_parser.py
154
  if hasattr(texts, "is_negative_prompt"):
155
  IS_NEGATIVE_PROMPT = texts.is_negative_prompt
156
  else:
157
- # Если атрибут не найден (редкий случай), считаем это позитивным промптом
158
- # чтобы веса применились.
159
  IS_NEGATIVE_PROMPT = False
160
 
161
- # Сбрасываем индекс чанка перед обработкой нового текста
162
  INDEX = 0
163
  return original_process_texts(self, texts)
164
 
@@ -166,11 +151,9 @@ def patch_embedder(cls):
166
  def patched_process_tokens(self, remade_batch_tokens: list, batch_multipliers: list):
167
  global INDEX, WEIGHTS, IS_NEGATIVE_PROMPT
168
 
169
- # Если веса не заданы пользователем, работаем как оригинальный метод
170
  if not WEIGHTS:
171
  return original_process_tokens(self, remade_batch_tokens, batch_multipliers)
172
 
173
- # Применяем веса только к позитивным промптам
174
  if INDEX >= 0 and not IS_NEGATIVE_PROMPT:
175
  batches: int = len(batch_multipliers)
176
 
@@ -179,31 +162,30 @@ def patch_embedder(cls):
179
  logger.debug(f"Applying weight {current_weight}x to Chunk {INDEX}")
180
 
181
  for b in range(batches):
182
- # batch_multipliers[b] - это список множителей (по умолчанию 1.0) для каждого токена
183
  for i in range(len(batch_multipliers[b])):
184
  batch_multipliers[b][i] *= current_weight
185
  else:
186
- # Если весов меньше, чем чанков, остальные получают 1.0 (стандарт)
187
  if not ChunkWeight._error_logged:
188
  logger.warning(f"Not enough weights provided! Chunk {INDEX} uses default weight 1.0.")
189
  ChunkWeight._error_logged = True
190
 
191
- # Увеличиваем индекс, так как process_tokens вызывается для каждого чанка (BREAK) по очереди
192
  INDEX += 1
193
 
194
  return original_process_tokens(self, remade_batch_tokens, batch_multipliers)
195
 
196
- # Применяем подмену методов
197
  cls.process_texts = patched_process_texts
198
  cls.process_tokens = patched_process_tokens
199
  cls._chunk_weight_patched = True
200
 
201
- # Сохраняем оригиналы для возможности отключения
202
  cls._original_process_texts = original_process_texts
203
  cls._original_process_tokens = original_process_tokens
204
 
205
 
206
- def apply_patches():
 
 
 
 
207
  global PATCHED
208
  if PATCHED: return
209
 
@@ -237,6 +219,5 @@ def remove_patches():
237
  PATCHED = False
238
  logger.info("Chunk Weight patches removed.")
239
 
240
- # Регистрируем коллбэки для автозапуска
241
  on_app_started(apply_patches)
242
  on_script_unloaded(remove_patches)
 
11
  from modules.ui_components import InputAccordion
12
 
13
  # ==============================================================================
14
+ # ЧАСТЬ 1: ЛОГГЕР
15
  # ==============================================================================
16
  class ColorCode:
17
  RESET = "\033[0m"
 
47
  # ЧАСТЬ 2: СКРИПТ (Логика весов)
48
  # ==============================================================================
49
 
 
 
50
  target_classes = []
51
 
52
  try:
53
  from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWordsBase
54
  target_classes.append(FrozenCLIPEmbedderWithCustomWordsBase)
55
  except ImportError:
 
56
  logger.warning("Base class FrozenCLIPEmbedderWithCustomWordsBase not found. Trying individual classes.")
57
  try:
58
  from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWords
 
99
  if not enable:
100
  return
101
 
 
102
  for v in weights.split(","):
103
  v = v.strip()
104
  if not v: continue
 
110
 
111
  p.extra_generation_params["Chunk Weights"] = ", ".join(str(v) for v in WEIGHTS)
112
 
 
 
113
  p.cached_c = [None, None]
114
  p.cached_uc = [None, None]
115
 
 
116
  if hasattr(p, 'cached_hr_c'):
117
  p.cached_hr_c = [None, None]
118
  p.cached_hr_uc = [None, None]
119
 
120
  def postprocess(self, *args):
 
121
  StableDiffusionProcessing.cached_c = [None, None]
122
  StableDiffusionProcessing.cached_uc = [None, None]
123
  StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
 
125
 
126
 
127
  # ==============================================================================
128
+ # ЧАСТЬ 3: ПАТЧИНГ
129
  # ==============================================================================
130
 
131
  def patch_embedder(cls):
 
 
132
  if hasattr(cls, '_chunk_weight_patched'):
133
  return
134
 
 
139
  def patched_process_texts(self, texts: List[str]):
140
  global IS_NEGATIVE_PROMPT, INDEX
141
 
 
 
142
  if hasattr(texts, "is_negative_prompt"):
143
  IS_NEGATIVE_PROMPT = texts.is_negative_prompt
144
  else:
 
 
145
  IS_NEGATIVE_PROMPT = False
146
 
 
147
  INDEX = 0
148
  return original_process_texts(self, texts)
149
 
 
151
  def patched_process_tokens(self, remade_batch_tokens: list, batch_multipliers: list):
152
  global INDEX, WEIGHTS, IS_NEGATIVE_PROMPT
153
 
 
154
  if not WEIGHTS:
155
  return original_process_tokens(self, remade_batch_tokens, batch_multipliers)
156
 
 
157
  if INDEX >= 0 and not IS_NEGATIVE_PROMPT:
158
  batches: int = len(batch_multipliers)
159
 
 
162
  logger.debug(f"Applying weight {current_weight}x to Chunk {INDEX}")
163
 
164
  for b in range(batches):
 
165
  for i in range(len(batch_multipliers[b])):
166
  batch_multipliers[b][i] *= current_weight
167
  else:
 
168
  if not ChunkWeight._error_logged:
169
  logger.warning(f"Not enough weights provided! Chunk {INDEX} uses default weight 1.0.")
170
  ChunkWeight._error_logged = True
171
 
 
172
  INDEX += 1
173
 
174
  return original_process_tokens(self, remade_batch_tokens, batch_multipliers)
175
 
 
176
  cls.process_texts = patched_process_texts
177
  cls.process_tokens = patched_process_tokens
178
  cls._chunk_weight_patched = True
179
 
 
180
  cls._original_process_texts = original_process_texts
181
  cls._original_process_tokens = original_process_tokens
182
 
183
 
184
+ def apply_patches(*args, **kwargs):
185
+ """
186
+ Применяет патчи.
187
+ Принимает *args, **kwargs, так как A1111 передает (demo, app) в этот коллбек.
188
+ """
189
  global PATCHED
190
  if PATCHED: return
191
 
 
219
  PATCHED = False
220
  logger.info("Chunk Weight patches removed.")
221
 
 
222
  on_app_started(apply_patches)
223
  on_script_unloaded(remove_patches)