# Source: sdas / sd-webui-chunk-weights / scripts/chunk_weighting.py
# Last change: "Update sd-webui-chunk-weights/scripts/chunk_weighting.py" by Dikz
# (commit 8fdfd44, verified), 7.45 kB
import logging
import sys
from functools import wraps
from typing import List
import gradio as gr
from modules import scripts
from modules.infotext_utils import PasteField
from modules.processing import StableDiffusionProcessing, StableDiffusionProcessingTxt2Img
from modules.script_callbacks import on_app_started, on_script_unloaded
from modules.ui_components import InputAccordion
# ==============================================================================
# PART 1: LOGGER
# ==============================================================================
class ColorCode:
    """ANSI escape sequences used to colorize log level names on the console."""

    RESET = "\033[0m"
    BLACK = "\033[0;90m"  # bright black, i.e. grey
    CYAN = "\033[0;36m"
    YELLOW = "\033[0;33m"
    RED = "\033[0;31m"

    # Level name -> color. Levels not listed here are printed without color.
    MAP = {
        "DEBUG": BLACK,
        "INFO": CYAN,
        "WARNING": YELLOW,
        "ERROR": RED,
        "CRITICAL": RED,
    }


class ColoredFormatter(logging.Formatter):
    """``logging.Formatter`` that wraps the record's level name in an ANSI color.

    The color is applied only for the duration of :meth:`format`. The record's
    ``levelname`` is restored afterwards, so other handlers or formatters that
    receive the same record see the plain level name rather than escape codes.
    """

    def format(self, record: logging.LogRecord) -> str:
        levelname = record.levelname
        color = ColorCode.MAP.get(levelname)
        if color is None:
            return super().format(record)
        record.levelname = f"{color}{levelname}{ColorCode.RESET}"
        try:
            return super().format(record)
        finally:
            # Undo the temporary change so the escape codes don't leak.
            record.levelname = levelname


# Dedicated console logger for this extension. propagate=False keeps records
# from also being emitted by ancestor (e.g. root) logger handlers.
logger = logging.getLogger("ChunkWeight")
logger.setLevel(logging.INFO)
logger.propagate = False
# getLogger returns the same object for the same name, so guard against
# attaching a second handler if this module is executed again.
if not logger.handlers:
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(ColoredFormatter("[%(name)s] %(levelname)s - %(message)s"))
    logger.addHandler(handler)
# ==============================================================================
# PART 2: SCRIPT (weighting logic)
# ==============================================================================
# Embedder classes whose process_texts/process_tokens methods get wrapped.
target_classes = []

try:
    # Preferred: patch the shared base class once; subclasses that do not
    # override the wrapped methods pick up the patch through inheritance.
    from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWordsBase
except ImportError:
    logger.warning("Base class FrozenCLIPEmbedderWithCustomWordsBase not found. Trying individual classes.")
    # Fallback: collect whichever concrete embedder classes can be imported.
    try:
        from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWords
    except ImportError:
        pass
    else:
        target_classes.append(FrozenCLIPEmbedderWithCustomWords)
    try:
        from modules.sd_hijack_open_clip import FrozenOpenCLIPEmbedderWithCustomWords
    except ImportError:
        pass
    else:
        target_classes.append(FrozenOpenCLIPEmbedderWithCustomWords)
else:
    target_classes.append(FrozenCLIPEmbedderWithCustomWordsBase)
# Module-level state shared by ChunkWeight.setup(), the patched embedder
# methods, and apply_patches()/remove_patches().
PATCHED = False  # True while the embedder wrappers are installed
INDEX: int = 0  # position of the next chunk handed to process_tokens
WEIGHTS: List[float] = []  # per-chunk multipliers parsed from the UI; empty means "off"
IS_NEGATIVE_PROMPT: bool = False  # whether the texts being encoded are a negative prompt
class ChunkWeight(scripts.Script):
    """Always-visible script that multiplies each prompt chunk by its own weight.

    The user enters comma-separated floats. ``setup`` parses them into the
    module-level ``WEIGHTS`` list, which the patched embedder methods read so
    that positive-prompt chunk ``i`` has its token multipliers scaled by
    ``WEIGHTS[i]``.
    """

    # True once the "not enough weights" warning was emitted; reset in setup()
    # so the warning appears at most once per generation.
    _error_logged: bool = False

    def title(self):
        return "Chunk Weight"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the accordion UI and register the infotext paste fields.

        Returns:
            ``[enable, weights]``: the components whose values arrive in
            ``setup`` as its ``enable`` and ``weights`` arguments.
        """
        with InputAccordion(False, label=self.title()) as enable:
            weights = gr.Textbox(
                label="Weighting (comma separated floats)",
                placeholder="1.0, 1.5, 0.8",
                value="",
                lines=1,
                max_lines=1,
            )
            # Presumably excludes this field from the webui's saved UI defaults.
            weights.do_not_save_to_config = True
        self.infotext_fields = [
            # Pasting generation parameters turns the accordion on exactly when
            # they contain a "Chunk Weights" entry, so the run can be reproduced.
            PasteField(enable, lambda params: "Chunk Weights" in params),
            PasteField(weights, "Chunk Weights"),
        ]
        return [enable, weights]

    def setup(self, p: "StableDiffusionProcessing", enable: bool, weights: str):
        """Parse the UI weights into ``WEIGHTS`` and prepare ``p`` for this run.

        Unparseable entries are logged and skipped. When at least one weight
        was parsed, the weights are recorded in the infotext and the
        per-instance conditioning caches of ``p`` are reset.

        Args:
            p: The processing object for the current generation.
            enable: State of the accordion checkbox.
            weights: Comma-separated floats; ``None`` is treated as empty.
        """
        WEIGHTS.clear()
        ChunkWeight._error_logged = False
        if not enable:
            return

        for raw in (weights or "").split(","):
            raw = raw.strip()
            if not raw:
                continue
            try:
                WEIGHTS.append(float(raw))
            except ValueError:
                logger.error(f'Failed to parse "{raw}" as number...')

        if not WEIGHTS:
            # Nothing to apply: the patched process_tokens passes through
            # untouched when WEIGHTS is empty, so skip the infotext entry and
            # keep normal conditioning caching.
            return

        p.extra_generation_params["Chunk Weights"] = ", ".join(str(w) for w in WEIGHTS)
        # Give this run fresh caches so conditioning cached by an earlier run
        # (with different or no chunk weights) is not reused.
        p.cached_c = [None, None]
        p.cached_uc = [None, None]
        if hasattr(p, 'cached_hr_c'):
            p.cached_hr_c = [None, None]
            p.cached_hr_uc = [None, None]

    def postprocess(self, *args):
        """Reset the class-level conditioning caches after a generation.

        ``setup`` gives a weighted run its own per-instance caches. Clearing the
        shared class-level caches as well ensures later runs cannot reuse
        conditioning that may have been computed with chunk weights applied.
        """
        StableDiffusionProcessing.cached_c = [None, None]
        StableDiffusionProcessing.cached_uc = [None, None]
        StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
        StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
# ==============================================================================
# PART 3: PATCHING
# ==============================================================================
def patch_embedder(cls):
    """Wrap ``cls.process_texts`` / ``cls.process_tokens`` with chunk weighting.

    * ``process_texts`` is wrapped to reset the chunk counter ``INDEX`` to 0 and
      to record whether the incoming texts are flagged as a negative prompt.
    * ``process_tokens`` is wrapped so that, for positive prompts, the token
      multipliers of chunk ``INDEX`` are scaled by ``WEIGHTS[INDEX]`` before the
      original method runs. ``INDEX`` then advances to the next chunk. Chunks
      beyond the last weight keep weight 1.0, with a single warning.

    The original methods are stored on the class (``_original_process_texts`` /
    ``_original_process_tokens``) so ``remove_patches`` can restore them. A class
    that already has ``_chunk_weight_patched``, either directly or via a base
    class, is left alone so the same methods are never wrapped twice.

    Args:
        cls: Embedder class providing ``process_texts(self, texts)`` and
            ``process_tokens(self, remade_batch_tokens, batch_multipliers)``.
    """
    if hasattr(cls, '_chunk_weight_patched'):
        return

    original_process_texts = cls.process_texts
    original_process_tokens = cls.process_tokens

    @wraps(original_process_texts)
    def patched_process_texts(self, texts: List[str]):
        global IS_NEGATIVE_PROMPT, INDEX
        # The texts object may carry an ``is_negative_prompt`` flag; a plain
        # list does not and is treated as a positive prompt.
        IS_NEGATIVE_PROMPT = getattr(texts, "is_negative_prompt", False)
        # Chunk numbering restarts for every new set of texts.
        INDEX = 0
        return original_process_texts(self, texts)

    @wraps(original_process_tokens)
    def patched_process_tokens(self, remade_batch_tokens: list, batch_multipliers: list):
        global INDEX
        # No weights configured, or a negative prompt: pass through unchanged.
        # INDEX is not advanced because negative prompts are never weighted.
        if not WEIGHTS or IS_NEGATIVE_PROMPT:
            return original_process_tokens(self, remade_batch_tokens, batch_multipliers)

        if INDEX < len(WEIGHTS):
            current_weight = WEIGHTS[INDEX]
            logger.debug(f"Applying weight {current_weight}x to Chunk {INDEX}")
            # Build fresh rows instead of scaling in place. Rows can be the very
            # same list object (presumably when identical prompts in a batch
            # reuse one tokenized chunk), and an in-place ``*=`` per row would
            # then apply the weight once per alias. Copying also avoids
            # mutating the caller's chunk data.
            batch_multipliers = [
                [multiplier * current_weight for multiplier in row]
                for row in batch_multipliers
            ]
        elif not ChunkWeight._error_logged:
            logger.warning(f"Not enough weights provided! Chunk {INDEX} uses default weight 1.0.")
            ChunkWeight._error_logged = True

        INDEX += 1
        return original_process_tokens(self, remade_batch_tokens, batch_multipliers)

    cls.process_texts = patched_process_texts
    cls.process_tokens = patched_process_tokens
    cls._chunk_weight_patched = True
    cls._original_process_texts = original_process_texts
    cls._original_process_tokens = original_process_tokens
def apply_patches(*args, **kwargs):
    """Attach the chunk-weight wrappers to every class in ``target_classes``.

    Accepts ``*args, **kwargs`` because A1111 passes ``(demo, app)`` to this
    callback; they are ignored. Does nothing if the patches are already applied.
    If no class could be patched, an error is logged and ``PATCHED`` stays False.
    """
    global PATCHED
    if PATCHED:
        return

    attached = 0
    for embedder_cls in target_classes:
        try:
            patch_embedder(embedder_cls)
            attached += 1
            logger.info(f"Chunk Weight logic attached to {embedder_cls.__name__}.")
        except Exception as exc:
            logger.error(f"Failed to patch {embedder_cls}: {exc}")

    if attached:
        PATCHED = True
        return
    logger.error("COULD NOT FIND EMBEDDER CLASSES. Chunk Weights will not work.")
def remove_patches():
    """Undo ``apply_patches`` by restoring the embedders' original methods.

    Only classes that carry the patch marker in their own ``__dict__`` are
    restored. A subclass that merely inherits the marker from a patched base
    has nothing of its own to undo; for such a class, deleting the marker would
    raise ``AttributeError`` because the marker lives on the base. Does nothing
    if the patches are not currently applied.
    """
    global PATCHED
    if not PATCHED:
        return
    for cls in target_classes:
        if '_chunk_weight_patched' not in vars(cls):
            continue
        cls.process_texts = cls._original_process_texts
        cls.process_tokens = cls._original_process_tokens
        del cls._chunk_weight_patched
        del cls._original_process_texts
        del cls._original_process_tokens
    PATCHED = False
    logger.info("Chunk Weight patches removed.")
# Lifecycle hooks: restore the original embedder methods when this script is
# unloaded, and install the wrappers once the app has started.
on_script_unloaded(remove_patches)
on_app_started(apply_patches)