import logging import torch import gradio as gr from modules import scripts, shared from modules.processing import StableDiffusionProcessing from modules.infotext_utils import PasteField from modules.ui_components import InputAccordion # Настройка логгера logger = logging.getLogger("ChunkWeight") logger.setLevel(logging.INFO) # ============================================================================ # ГЛОБАЛЬНОЕ СОСТОЯНИЕ # ============================================================================ STATE = { 'pos_weights': [], 'neg_weights': [], 'enabled': False, 'original_method': None, } # ============================================================================ # ЛОГИКА ОБРАБОТКИ ТЕНЗОРОВ (FIXED FOR SDXL) # ============================================================================ def apply_weight_to_cond(cond, weight): """ Применяет вес к кондишену (Тензору или Словарю). """ if weight == 1.0: return cond if isinstance(cond, dict): # Логика для SDXL (Dict wrapper) new_cond = cond.copy() # Умножаем текстовые эмбеддинги (Cross-Attention) for key in ['crossattn', 'c_crossattn', 'open_clip_projected']: if key in new_cond: new_cond[key] = new_cond[key] * weight # Вектора стиля (pooled) 'vector' мы здесь НЕ умножаем на вес скалярно, # так как это может сломать нормализацию. Их вес будет учтен при усреднении (merge_conds). return new_cond elif isinstance(cond, torch.Tensor): # Логика для SD1.5 (Простой тензор) return cond * weight return cond def merge_conds(cond_list, weights=None): """ Склеивает список кондишенов чанков обратно в один промпт. ИСПРАВЛЕНО: Корректная обработка SDXL Pooled Vectors. """ if not cond_list: return None first = cond_list[0] # --- Склейка для SDXL (Dictionary) --- if isinstance(first, dict): merged = {} for key in first.keys(): tensors = [c[key] for c in cond_list if key in c] if not tensors: continue # Проверяем размерность, чтобы понять, как склеивать ndim = len(tensors[0].shape) if ndim == 3: # [Batch, Tokens, Dim] -> CrossAttention. Склеиваем последовательно (в длину). merged[key] = torch.cat(tensors, dim=1) elif ndim == 2: # [Batch, Dim] -> Pooled Vector. Склеивать нельзя (ошибка mat1/mat2)! # Нужно усреднить вектора всех чанков. if weights and len(weights) == len(tensors): # Взвешенное среднее: (V1*w1 + V2*w2) / (w1+w2) # Это позволяет "весу чанка" влиять на глобальный стиль stacked = torch.stack(tensors) # [N, B, D] # Приводим веса к форме [N, 1, 1] для умножения w_tensor = torch.tensor(weights, device=stacked.device, dtype=stacked.dtype).view(-1, 1, 1) weighted_sum = (stacked * w_tensor).sum(dim=0) # [B, D] total_weight = sum(weights) if sum(weights) != 0 else 1.0 merged[key] = weighted_sum / total_weight else: # Простое среднее, если весов нет merged[key] = torch.stack(tensors).mean(dim=0) else: # Фолбэк для странных размерностей merged[key] = tensors[0] return merged # --- Склейка для SD1.5 (Tensor) --- elif isinstance(first, torch.Tensor): # Здесь всегда [Batch, Tokens, Dim], просто склеиваем return torch.cat(cond_list, dim=1) return first def patched_get_learned_conditioning(prompts): """ Подмененный метод получения эмбеддингов. """ global STATE original_method = STATE['original_method'] # Фолбэк безопасности if not STATE['enabled']: return original_method(prompts) if isinstance(prompts, str): prompts = [prompts] final_results = [] for i, prompt in enumerate(prompts): # 1. Разбиваем по BREAK chunks = prompt.split("BREAK") chunk_tensors = [] # 2. Определяем веса для текущего промпта # Эвристика: если кол-во весов совпадает с кол-вом чанков - используем их. # Это позволяет отличить Pos от Neg промпта, если у них разное кол-во чанков. current_weights = [] if len(STATE['pos_weights']) >= len(chunks) and len(STATE['pos_weights']) > 0: current_weights = STATE['pos_weights'][:len(chunks)] # Берем ровно столько, сколько чанков elif len(STATE['neg_weights']) >= len(chunks) and len(STATE['neg_weights']) > 0: current_weights = STATE['neg_weights'][:len(chunks)] else: current_weights = [1.0] * len(chunks) # 3. Обработка чанков for idx, chunk_text in enumerate(chunks): # Получаем эмбеддинг чанка (Оригинальный метод) cond = original_method([chunk_text]) # Получаем вес w = current_weights[idx] # Применяем вес (только к crossattn) if w != 1.0: cond = apply_weight_to_cond(cond, w) chunk_tensors.append(cond) # 4. Склеиваем (с учетом весов для Pooled векторов) merged = merge_conds(chunk_tensors, weights=current_weights) final_results.append(merged) # 5. Собираем итоговый батч (dim=0) if len(final_results) > 1: if isinstance(final_results[0], dict): # Batching для SDXL словарей batch_merged = {} for key in final_results[0].keys(): batch_merged[key] = torch.cat([r[key] for r in final_results], dim=0) return batch_merged else: # Batching для SD1.5 тензоров return torch.cat(final_results, dim=0) else: return final_results[0] # ============================================================================ # ИНТЕРФЕЙС # ============================================================================ class ChunkWeightUltimateFixed(scripts.Script): def title(self): return "Chunk Weight (Ultimate SDXL Fix)" def show(self, is_img2img): return scripts.AlwaysVisible def ui(self, is_img2img): with InputAccordion(False, label="Chunk Weights") as enable: gr.Markdown("Версия с исправленной поддержкой SDXL. Разбивает по `BREAK`.") pos_weights = gr.Textbox(label="Positive Weights", placeholder="1.2, 0.8", lines=1) neg_weights = gr.Textbox(label="Negative Weights", placeholder="1.0, 0.5", lines=1) self.infotext_fields = [ PasteField(pos_weights, "ChunkW+"), PasteField(neg_weights, "ChunkW-"), ] return [enable, pos_weights, neg_weights] def process(self, p: StableDiffusionProcessing, enable: bool, pos_str: str, neg_str: str): global STATE self.remove_patch() # Очистка старых патчей if not enable: return def parse(s): try: return [float(x.strip()) for x in s.split(',') if x.strip()] except: return [] STATE['pos_weights'] = parse(pos_str) STATE['neg_weights'] = parse(neg_str) STATE['enabled'] = True if STATE['pos_weights']: p.extra_generation_params["ChunkW+"] = str(STATE['pos_weights']) if STATE['neg_weights']: p.extra_generation_params["ChunkW-"] = str(STATE['neg_weights']) # Патчинг get_learned_conditioning if shared.sd_model and hasattr(shared.sd_model, 'get_learned_conditioning'): logger.info("ChunkWeight: Patching model...") STATE['original_method'] = shared.sd_model.get_learned_conditioning shared.sd_model.get_learned_conditioning = patched_get_learned_conditioning # Сброс кэшей A1111 (Обязательно!) p.cached_c = [None, None] p.cached_uc = [None, None] p.cached_hr_c = [None, None] p.cached_hr_uc = [None, None] def postprocess(self, p, processed, *args): self.remove_patch() def remove_patch(self): global STATE if STATE['original_method'] and shared.sd_model: shared.sd_model.get_learned_conditioning = STATE['original_method'] STATE['original_method'] = None STATE['enabled'] = False def on_unload(): global STATE if STATE['original_method'] and shared.sd_model: shared.sd_model.get_learned_conditioning = STATE['original_method'] scripts.script_callbacks.on_script_unloaded(on_unload)