| | import logging |
| | import sys |
| | from functools import wraps |
| | from typing import List |
| |
|
| | import gradio as gr |
| | from modules import scripts |
| | from modules.infotext_utils import PasteField |
| | from modules.processing import StableDiffusionProcessing, StableDiffusionProcessingTxt2Img |
| | from modules.script_callbacks import on_app_started, on_script_unloaded |
| | from modules.ui_components import InputAccordion |
| |
|
| | |
| | |
| | |
class ColorCode:
    """ANSI escape sequences used to colorize log level names on stdout."""

    RESET = "\033[0m"
    BLACK = "\033[0;90m"
    CYAN = "\033[0;36m"
    YELLOW = "\033[0;33m"
    RED = "\033[0;31m"

    # Level name -> escape code prefix; levels not listed stay uncolored.
    MAP = {
        "DEBUG": BLACK,
        "INFO": CYAN,
        "WARNING": YELLOW,
        "ERROR": RED,
    }


class ColoredFormatter(logging.Formatter):
    """Formatter that wraps the record's level name in ANSI color codes.

    Fix over the previous version: it assigned the colorized name back to
    ``record.levelname`` and never restored it, so the mutation leaked to
    any other handler/formatter processing the same LogRecord. The name is
    now restored after formatting.
    """

    def format(self, record):
        original = record.levelname
        color = ColorCode.MAP.get(original)
        if color is not None:
            record.levelname = f"{color}{original}{ColorCode.RESET}"
        try:
            return super().format(record)
        finally:
            # Undo the temporary mutation so the record stays clean.
            record.levelname = original
| |
|
# Extension-wide logger. Propagation is disabled so messages are not
# duplicated through the root logger's handlers.
logger = logging.getLogger("ChunkWeight")
logger.setLevel(logging.INFO)
logger.propagate = False

# Guard against attaching a second handler when this module is re-imported
# (e.g. on webui script reload).
if not logger.handlers:
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(ColoredFormatter("[%(name)s] %(levelname)s - %(message)s"))
    logger.addHandler(handler)
| |
|
| | |
| | |
| | |
| |
|
# Embedder classes whose prompt-processing methods get monkey-patched.
target_classes = []

try:
    # Preferred path: the shared base class, which newer webui versions expose.
    from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWordsBase
    target_classes.append(FrozenCLIPEmbedderWithCustomWordsBase)
except ImportError:
    # Fallback for older webui layouts: patch the concrete CLIP / OpenCLIP
    # embedder classes individually; either import may be missing.
    logger.warning("Base class FrozenCLIPEmbedderWithCustomWordsBase not found. Trying individual classes.")
    try:
        from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWords
        target_classes.append(FrozenCLIPEmbedderWithCustomWords)
    except ImportError: pass

    try:
        from modules.sd_hijack_open_clip import FrozenOpenCLIPEmbedderWithCustomWords
        target_classes.append(FrozenOpenCLIPEmbedderWithCustomWords)
    except ImportError: pass
| |
|
# Module-level state shared between the Script instance and the patched
# embedder methods (the hooks run deep inside cond computation, so plain
# globals are used to pass data to them).
IS_NEGATIVE_PROMPT: bool = False  # True while the negative-prompt batch is processed
WEIGHTS: List[float] = []         # per-chunk multipliers parsed from the UI textbox
INDEX: int = 0                    # index of the prompt chunk currently being processed
PATCHED = False                   # whether apply_patches() has already run
| |
|
class ChunkWeight(scripts.Script):
    """Applies a user-supplied weight multiplier to each prompt chunk.

    The UI collects a comma-separated list of floats; ``setup()`` parses it
    into the module-level ``WEIGHTS`` list that the patched embedder methods
    consume, and shadows the processing object's cond caches so previously
    cached (unweighted) conds are never reused.
    """

    # Set once per generation when fewer weights than chunks were supplied,
    # so the "not enough weights" warning is logged only once.
    _error_logged: bool = False

    def title(self):
        return "Chunk Weight"

    def show(self, is_img2img):
        # Always shown, in both txt2img and img2img.
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        with InputAccordion(False, label=self.title()) as enable:
            weights = gr.Textbox(
                label="Weighting (comma separated floats)",
                placeholder="1.0, 1.5, 0.8",
                value="",
                lines=1,
                max_lines=1,
            )
            # Keep user-specific weights out of ui-config.json.
            weights.do_not_save_to_config = True

        # Allow "Chunk Weights" to be restored from generation infotext.
        self.infotext_fields = [PasteField(weights, "Chunk Weights")]
        return [enable, weights]

    def setup(self, p: "StableDiffusionProcessing", enable: bool, weights: str):
        """Parse the weight string and disable cond caching for this job.

        Runs once per processing job, before conds are computed.
        """
        WEIGHTS.clear()
        ChunkWeight._error_logged = False

        if not enable:
            return

        for v in weights.split(","):
            v = v.strip()
            if not v:
                continue
            try:
                WEIGHTS.append(float(v))
            except ValueError:
                logger.error(f'Failed to parse "{v}" as number...')
                continue

        p.extra_generation_params["Chunk Weights"] = ", ".join(str(v) for v in WEIGHTS)

        # Shadow the class-level cond caches with fresh per-job instance lists
        # so cached unweighted conds from earlier runs are never reused.
        p.cached_c = [None, None]
        p.cached_uc = [None, None]

        # Hires-fix caches exist only on txt2img processing objects.
        if hasattr(p, 'cached_hr_c'):
            p.cached_hr_c = [None, None]
            p.cached_hr_uc = [None, None]

    def postprocess(self, *args):
        """Flush the shared class-level cond caches after a weighted run.

        Fix: the previous version reset these caches unconditionally after
        *every* generation, discarding webui's prompt-cond cache even while
        this extension was disabled. Only flush when weights were active.
        """
        if not WEIGHTS:
            return

        StableDiffusionProcessing.cached_c = [None, None]
        StableDiffusionProcessing.cached_uc = [None, None]
        StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
        StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
| |
|
| |
|
| | |
| | |
| | |
| |
|
def patch_embedder(cls):
    """Install chunk-weight hooks on an embedder class, exactly once.

    Wraps ``process_texts`` (to detect the negative prompt and reset the
    chunk counter) and ``process_tokens`` (to scale each chunk's token
    multipliers). The originals are stashed on the class so that
    remove_patches() can restore them later.
    """
    if hasattr(cls, '_chunk_weight_patched'):
        return

    orig_texts = cls.process_texts
    orig_tokens = cls.process_tokens

    @wraps(orig_texts)
    def process_texts_hook(self, texts: List[str]):
        global IS_NEGATIVE_PROMPT, INDEX

        # webui tags the negative-prompt batch with this attribute.
        IS_NEGATIVE_PROMPT = getattr(texts, "is_negative_prompt", False)

        # A new prompt starts: count chunks from zero again.
        INDEX = 0
        return orig_texts(self, texts)

    @wraps(orig_tokens)
    def process_tokens_hook(self, remade_batch_tokens: list, batch_multipliers: list):
        global INDEX, WEIGHTS, IS_NEGATIVE_PROMPT

        if not WEIGHTS:
            return orig_tokens(self, remade_batch_tokens, batch_multipliers)

        # Only the positive prompt is weighted; the negative prompt passes
        # through untouched.
        if INDEX >= 0 and not IS_NEGATIVE_PROMPT:
            if INDEX < len(WEIGHTS):
                current = WEIGHTS[INDEX]
                logger.debug(f"Applying weight {current}x to Chunk {INDEX}")

                # Scale every token multiplier in every batch row of this chunk.
                for row in batch_multipliers:
                    for i, mult in enumerate(row):
                        row[i] = mult * current
            else:
                if not ChunkWeight._error_logged:
                    logger.warning(f"Not enough weights provided! Chunk {INDEX} uses default weight 1.0.")
                    ChunkWeight._error_logged = True

            INDEX += 1

        return orig_tokens(self, remade_batch_tokens, batch_multipliers)

    cls.process_texts = process_texts_hook
    cls.process_tokens = process_tokens_hook
    cls._chunk_weight_patched = True

    # Originals kept for clean unpatching.
    cls._original_process_texts = orig_texts
    cls._original_process_tokens = orig_tokens
| |
|
| |
|
def apply_patches(*args, **kwargs):
    """Attach the chunk-weight hooks to every discovered embedder class.

    Accepts *args/**kwargs because A1111 passes (demo, app) into the
    on_app_started callback.
    """
    global PATCHED
    if PATCHED:
        return

    attached = 0
    for cls in target_classes:
        try:
            patch_embedder(cls)
            attached += 1
            logger.info(f"Chunk Weight logic attached to {cls.__name__}.")
        except Exception as e:
            # Best-effort: a failure on one class must not block the others.
            logger.error(f"Failed to patch {cls}: {e}")

    if attached:
        PATCHED = True
    else:
        logger.error("COULD NOT FIND EMBEDDER CLASSES. Chunk Weights will not work.")
| |
|
| |
|
def remove_patches():
    """Restore the original embedder methods and drop the patch markers."""
    global PATCHED
    if not PATCHED:
        return

    for cls in target_classes:
        if not hasattr(cls, '_chunk_weight_patched'):
            continue
        cls.process_texts = cls._original_process_texts
        cls.process_tokens = cls._original_process_tokens
        del cls._chunk_weight_patched
        del cls._original_process_texts
        del cls._original_process_tokens

    PATCHED = False
    logger.info("Chunk Weight patches removed.")
| |
|
# Install the embedder hooks once the web app has started, and undo them
# when the extension is unloaded/reloaded.
on_app_started(apply_patches)
on_script_unloaded(remove_patches)