import logging
import math
import sys
from functools import wraps
from typing import List

import gradio as gr

from modules import scripts
from modules.infotext_utils import PasteField
from modules.processing import StableDiffusionProcessing, StableDiffusionProcessingTxt2Img
from modules.script_callbacks import on_app_started, on_script_unloaded
from modules.ui_components import InputAccordion

# ==============================================================================
# PART 1: LOGGER
# ==============================================================================


class ColorCode:
    """ANSI escape sequences used to colorize the level name in log output."""

    RESET = "\033[0m"
    BLACK = "\033[0;90m"
    CYAN = "\033[0;36m"
    YELLOW = "\033[0;33m"
    RED = "\033[0;31m"

    # Level name -> color prefix; levels not listed here are printed uncolored.
    MAP = {
        "DEBUG": BLACK,
        "INFO": CYAN,
        "WARNING": YELLOW,
        "ERROR": RED,
    }


class ColoredFormatter(logging.Formatter):
    """Formatter that wraps the record's level name in an ANSI color code.

    The record's ``levelname`` is only replaced for the duration of this call
    and restored afterwards, so other handlers (or a second format pass) see
    the plain level name rather than an already-colorized one.
    """

    def format(self, record):
        color = ColorCode.MAP.get(record.levelname)
        if color is None:
            return super().format(record)

        plain_levelname = record.levelname
        record.levelname = f"{color}{plain_levelname}{ColorCode.RESET}"
        try:
            return super().format(record)
        finally:
            record.levelname = plain_levelname


logger = logging.getLogger("ChunkWeight")
logger.setLevel(logging.INFO)
# Do not hand records to ancestor loggers' handlers; this logger prints on its own.
logger.propagate = False
# getLogger returns the same object on re-import, so only attach one handler.
if not logger.handlers:
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(ColoredFormatter("[%(name)s] %(levelname)s - %(message)s"))
    logger.addHandler(handler)

# ==============================================================================
# PART 2: SCRIPT (weight logic)
# ==============================================================================

# Embedder classes whose process_texts/process_tokens get wrapped. The shared
# base class is preferred; the concrete classes are a fallback when it is not
# importable.
target_classes = []
try:
    from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWordsBase

    target_classes.append(FrozenCLIPEmbedderWithCustomWordsBase)
except ImportError:
    logger.warning("Base class FrozenCLIPEmbedderWithCustomWordsBase not found. Trying individual classes.")
    try:
        from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWords

        target_classes.append(FrozenCLIPEmbedderWithCustomWords)
    except ImportError:
        pass
    try:
        from modules.sd_hijack_open_clip import FrozenOpenCLIPEmbedderWithCustomWords

        target_classes.append(FrozenOpenCLIPEmbedderWithCustomWords)
    except ImportError:
        pass

# Module-level state shared between the script callbacks and the patched
# embedder methods.
IS_NEGATIVE_PROMPT: bool = False  # set from the texts passed to process_texts
WEIGHTS: List[float] = []  # parsed weights; WEIGHTS[i] scales positive-prompt chunk i
INDEX: int = 0  # number of positive-prompt chunks seen since the last process_texts call
PATCHED = False  # True once apply_patches() has wrapped at least one class


class ChunkWeight(scripts.Script):
    """Always-visible script that scales the token multipliers of each prompt chunk."""

    # Makes the "not enough weights" warning fire at most once per generation.
    _error_logged: bool = False

    def title(self):
        return "Chunk Weight"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        with InputAccordion(False, label=self.title()) as enable:
            weights = gr.Textbox(
                label="Weighting (comma separated floats)",
                placeholder="1.0, 1.5, 0.8",
                value="",
                lines=1,
                max_lines=1,
            )
            weights.do_not_save_to_config = True

        self.infotext_fields = [PasteField(weights, "Chunk Weights")]
        return [enable, weights]

    def setup(self, p: "StableDiffusionProcessing", enable: bool, weights: str):
        """Parse the comma-separated weights for this generation.

        Always clears WEIGHTS first, so a disabled script leaves the patched
        embedder passing chunks through unchanged. Empty entries are ignored;
        unparseable or non-finite entries are logged and skipped. When enabled,
        this processing object gets fresh conditioning caches so its conds are
        recomputed with the current weights.
        """
        WEIGHTS.clear()
        ChunkWeight._error_logged = False

        if not enable:
            return

        for v in (weights or "").split(","):
            v = v.strip()
            if not v:
                continue
            try:
                value = float(v)
            except ValueError:
                logger.error(f'Failed to parse "{v}" as number...')
                continue
            # float() accepts "nan"/"inf", which would poison the conditioning.
            if not math.isfinite(value):
                logger.error(f'Ignoring non-finite weight "{v}"...')
                continue
            WEIGHTS.append(value)

        # Only record the infotext entry when at least one weight was parsed.
        if WEIGHTS:
            p.extra_generation_params["Chunk Weights"] = ", ".join(str(v) for v in WEIGHTS)

        # Instance-level caches shadow the class-level ones, so conds computed
        # with these weights are not served from (or stored into) the shared
        # cache. Presumably the cache key does not include WEIGHTS -- confirm.
        p.cached_c = [None, None]
        p.cached_uc = [None, None]
        if hasattr(p, "cached_hr_c"):
            p.cached_hr_c = [None, None]
            p.cached_hr_uc = [None, None]

    def postprocess(self, *args):
        """Reset the class-level conditioning caches after a generation."""
        StableDiffusionProcessing.cached_c = [None, None]
        StableDiffusionProcessing.cached_uc = [None, None]
        StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
        StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]


# ==============================================================================
# PART 3: PATCHING
# ==============================================================================


def patch_embedder(cls):
    """Wrap ``cls.process_texts`` and ``cls.process_tokens`` with chunk weighting.

    The original methods are stored on the class so remove_patches() can put
    them back. A class that already carries the patch marker (set directly or
    inherited from a patched base class) is left alone so the weight is not
    applied twice.
    """
    if hasattr(cls, "_chunk_weight_patched"):
        return

    original_process_texts = cls.process_texts
    original_process_tokens = cls.process_tokens

    @wraps(original_process_texts)
    def patched_process_texts(self, texts: List[str]):
        # Start of a new prompt encoding: remember whether it is the negative
        # prompt and restart chunk numbering.
        global IS_NEGATIVE_PROMPT, INDEX
        IS_NEGATIVE_PROMPT = getattr(texts, "is_negative_prompt", False)
        INDEX = 0
        return original_process_texts(self, texts)

    @wraps(original_process_tokens)
    def patched_process_tokens(self, remade_batch_tokens: list, batch_multipliers: list):
        global INDEX

        # Negative prompts and disabled weighting pass through untouched (and
        # do not advance the chunk counter).
        if not WEIGHTS or IS_NEGATIVE_PROMPT:
            return original_process_tokens(self, remade_batch_tokens, batch_multipliers)

        chunk_index = INDEX
        INDEX += 1

        if chunk_index < len(WEIGHTS):
            current_weight = WEIGHTS[chunk_index]
            logger.debug("Applying weight %sx to Chunk %s", current_weight, chunk_index)
            # Scale into new lists instead of mutating in place: rows of
            # batch_multipliers may alias one list (upstream process_texts
            # appears to reuse chunk objects for identical prompt lines --
            # verify in sd_hijack_clip.py), and an in-place *= would then
            # compound the weight once per alias. This also leaves the caller's
            # chunk objects unmodified.
            batch_multipliers = [[m * current_weight for m in row] for row in batch_multipliers]
        elif not ChunkWeight._error_logged:
            logger.warning(f"Not enough weights provided! Chunk {chunk_index} uses default weight 1.0.")
            ChunkWeight._error_logged = True

        return original_process_tokens(self, remade_batch_tokens, batch_multipliers)

    cls.process_texts = patched_process_texts
    cls.process_tokens = patched_process_tokens
    cls._chunk_weight_patched = True
    cls._original_process_texts = original_process_texts
    cls._original_process_tokens = original_process_tokens


def apply_patches(*args, **kwargs):
    """Install the chunk-weight wrappers on every class in target_classes.

    Accepts ``*args, **kwargs`` because A1111 passes ``(demo, app)`` to this
    callback; they are ignored. Does nothing if the patches are already applied.
    """
    global PATCHED
    if PATCHED:
        return

    patched_count = 0
    for cls in target_classes:
        try:
            patch_embedder(cls)
            patched_count += 1
            logger.info(f"Chunk Weight logic attached to {cls.__name__}.")
        except Exception as e:
            # Best effort: one failing class must not stop the others.
            logger.error(f"Failed to patch {cls}: {e}")

    if patched_count > 0:
        PATCHED = True
    else:
        logger.error("COULD NOT FIND EMBEDDER CLASSES. Chunk Weights will not work.")


def remove_patches():
    """Restore the original embedder methods saved by patch_embedder()."""
    global PATCHED
    if not PATCHED:
        return

    for cls in target_classes:
        # Only undo patches installed directly on this class; an inherited
        # marker belongs to a base class that is restored on its own.
        if "_chunk_weight_patched" not in vars(cls):
            continue
        cls.process_texts = cls._original_process_texts
        cls.process_tokens = cls._original_process_tokens
        del cls._chunk_weight_patched
        del cls._original_process_texts
        del cls._original_process_tokens

    PATCHED = False
    logger.info("Chunk Weight patches removed.")


on_app_started(apply_patches)
on_script_unloaded(remove_patches)