import copy
import logging
import re

import torch
import gradio as gr

from modules import scripts, shared
from modules.processing import StableDiffusionProcessing
from modules.infotext_utils import PasteField
from modules.ui_components import InputAccordion
| |
|
| | |
# Logger used by this script for its status messages.
logger = logging.getLogger("ChunkWeight")
logger.setLevel(logging.INFO)

# Mutable state shared between the Script instance (which fills it in during
# ``process``) and the module-level patched conditioning function (which reads it).
STATE = dict(
    pos_weights=[],        # per-chunk weights parsed from the "Positive Weights" box
    neg_weights=[],        # per-chunk weights parsed from the "Negative Weights" box
    enabled=False,         # True while the patched function should apply weights
    original_method=None,  # model's get_learned_conditioning saved before patching
)
| |
|
| | |
| | |
| | |
| |
|
def apply_weight_to_cond(cond, weight):
    """Scale a conditioning (a tensor or a dict of tensors) by ``weight``.

    * A weight of exactly 1.0 returns ``cond`` itself, untouched.
    * A tensor is multiplied by ``weight``.
    * For a dict, a shallow copy is returned in which only the
      ``'crossattn'``, ``'c_crossattn'`` and ``'open_clip_projected'`` entries
      (when present) are multiplied. All other entries are the same objects
      as in ``cond``.
    * Any other object is returned unchanged.
    """
    if weight == 1.0:
        return cond

    if isinstance(cond, torch.Tensor):
        return cond * weight

    if not isinstance(cond, dict):
        return cond

    scaled = cond.copy()
    scaled.update({
        name: scaled[name] * weight
        for name in ('crossattn', 'c_crossattn', 'open_clip_projected')
        if name in scaled
    })
    return scaled
| |
|
| |
|
def merge_conds(cond_list, weights=None):
    """Join per-chunk conditionings back into a single prompt conditioning.

    Tensors are concatenated along dim 1. For dicts, every key of the first
    chunk is merged according to the rank of its tensors:

    * rank 3: concatenated along dim 1;
    * rank 2: treated as pooled vectors (such as SDXL's) and averaged over
      the chunks. When ``weights`` has exactly one entry per tensor, this is a
      weighted average. If the weights sum to zero, the weighted sum is
      divided by 1.0 instead.
    * any other rank: the first chunk's tensor is kept.

    Any other element type is returned as the first element. Returns ``None``
    for an empty list.
    """
    if not cond_list:
        return None

    head = cond_list[0]
    if isinstance(head, torch.Tensor):
        return torch.cat(cond_list, dim=1)
    if not isinstance(head, dict):
        return head

    merged = {}
    for key in head:
        parts = [chunk[key] for chunk in cond_list if key in chunk]
        if not parts:
            continue

        rank = len(parts[0].shape)
        if rank == 3:
            merged[key] = torch.cat(parts, dim=1)
            continue
        if rank != 2:
            merged[key] = parts[0]
            continue

        stacked = torch.stack(parts)
        if not (weights and len(weights) == len(parts)):
            merged[key] = stacked.mean(dim=0)
            continue

        # One coefficient per chunk, broadcast over the stacked (chunk, ...) axis.
        coeffs = torch.tensor(weights, device=stacked.device, dtype=stacked.dtype).view(-1, 1, 1)
        denominator = sum(weights)
        if denominator == 0:
            denominator = 1.0
        merged[key] = (stacked * coeffs).sum(dim=0) / denominator

    return merged
| |
|
| |
|
# Splits on BREAK only when it is a standalone word (so "BREAKFAST" is left
# intact) and swallows the whitespace around it.
_BREAK_RE = re.compile(r"\s*\bBREAK\b\s*")


def _split_chunks(prompt):
    """Split ``prompt`` into chunk texts at every standalone ``BREAK`` keyword."""
    return _BREAK_RE.split(prompt)


def _resolve_weights(weights, count):
    """Return exactly ``count`` chunk weights: the user's values in order, then 1.0 for the rest."""
    chosen = list(weights[:count])
    return chosen + [1.0] * (count - len(chosen))


def _batch_like(template, texts):
    """Build a prompt batch holding ``texts`` that keeps ``template``'s list subclass and attributes.

    The batch handed to ``get_learned_conditioning`` may be a list subclass
    carrying per-batch attributes (e.g. an ``is_negative_prompt`` flag). Re-encoding
    each chunk as a bare list would drop them. ``copy.copy`` duplicates the
    object including its ``__dict__`` without calling ``__init__``. The items
    are then replaced. Falls back to a plain list when the copy is not possible.
    """
    if not isinstance(template, list) or type(template) is list:
        return list(texts)
    try:
        batch = copy.copy(template)
        batch[:] = texts
    except (TypeError, copy.Error):
        return list(texts)
    return batch


def _pad_tokens(tensor, pad, target):
    """Append copies of ``pad`` along dim 1 until ``tensor`` has ``target`` tokens.

    Both are expected to be 3-D tensors indexed as ``[batch, token, ...]``.
    ``pad`` is expected to have batch size 1 or match ``tensor``'s; it is
    expanded to ``tensor``'s batch size.
    """
    missing = target - tensor.shape[1]
    if missing <= 0:
        return tensor
    repeats = -(-missing // pad.shape[1])  # ceiling division
    filler = pad.repeat(1, repeats, 1)[:, :missing]
    filler = filler.to(tensor).expand(tensor.shape[0], -1, -1)
    return torch.cat([tensor, filler], dim=1)


def _equalize_token_lengths(results, encode_empty):
    """Pad the 3-D tensors of every result to a common length along dim 1.

    Prompts in one batch (e.g. prompt-editing schedule variants) can split into
    a different number or size of chunks. Their merged conditionings then have
    different token counts and cannot be concatenated along dim 0. Shorter ones
    are padded with the encoding returned by ``encode_empty`` (called at most
    once). All results are assumed to share the first result's structure.
    """
    first = results[0]

    if isinstance(first, dict):
        keys = [k for k, v in first.items() if isinstance(v, torch.Tensor) and v.ndim == 3]
        targets = {k: max(r[k].shape[1] for r in results) for k in keys}
        if all(r[k].shape[1] == targets[k] for r in results for k in keys):
            return results
        pad = encode_empty()
        padded = []
        for r in results:
            fixed = r.copy()
            for k in keys:
                fixed[k] = _pad_tokens(r[k], pad[k], targets[k])
            padded.append(fixed)
        return padded

    if isinstance(first, torch.Tensor) and first.ndim == 3:
        target = max(r.shape[1] for r in results)
        if all(r.shape[1] == target for r in results):
            return results
        pad = encode_empty()
        return [_pad_tokens(r, pad, target) for r in results]

    return results


def patched_get_learned_conditioning(prompts):
    """Replacement for the model's ``get_learned_conditioning`` that weights BREAK chunks.

    When ``STATE['enabled']`` is false this simply delegates to the saved
    original method. Otherwise each prompt is processed as follows:

    1. It is split at standalone ``BREAK`` keywords.
    2. Every chunk is encoded separately with the original method.
    3. Each chunk is scaled by its weight (``apply_weight_to_cond``).
    4. The chunks are joined back together (``merge_conds``).

    Weights come from ``STATE['neg_weights']`` when the batch has a truthy
    ``is_negative_prompt`` attribute, and from ``STATE['pos_weights']``
    otherwise. A batch without that attribute cannot be classified and is
    treated as positive. Chunks beyond the supplied weights get 1.0.

    Args:
        prompts: a prompt string, or a list (possibly a list subclass carrying
            batch attributes) of prompt strings.

    Returns:
        The conditioning in the form the original method produces (a tensor or
        a dict of tensors). Multiple prompts are concatenated along dim 0 after
        their token lengths are equalized.
    """
    original_method = STATE['original_method']

    if not STATE['enabled']:
        return original_method(prompts)

    if isinstance(prompts, str):
        prompts = [prompts]

    # Nothing to weight; let the original method handle an empty batch.
    if not prompts:
        return original_method(prompts)

    is_negative = bool(getattr(prompts, 'is_negative_prompt', False))
    user_weights = STATE['neg_weights'] if is_negative else STATE['pos_weights']

    final_results = []
    for prompt in prompts:
        chunks = _split_chunks(prompt)
        weights = _resolve_weights(user_weights, len(chunks))

        chunk_conds = []
        for chunk_text, w in zip(chunks, weights):
            # Keep the batch's type/attributes so the original encoder sees the
            # same metadata it would have for the whole prompt.
            cond = original_method(_batch_like(prompts, [chunk_text]))
            if w != 1.0:
                cond = apply_weight_to_cond(cond, w)
            chunk_conds.append(cond)

        final_results.append(merge_conds(chunk_conds, weights=weights))

    if len(final_results) == 1:
        return final_results[0]

    # Only the 3-D tensors of the empty-prompt encoding are used for padding.
    final_results = _equalize_token_lengths(final_results, lambda: original_method([""]))

    if isinstance(final_results[0], dict):
        return {key: torch.cat([r[key] for r in final_results], dim=0) for key in final_results[0]}
    return torch.cat(final_results, dim=0)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class ChunkWeightUltimateFixed(scripts.Script):
    """Always-visible script that applies per-chunk weights to BREAK-separated prompts.

    While a generation with the accordion enabled runs, the current model's
    ``get_learned_conditioning`` is replaced by
    ``patched_get_learned_conditioning``. The original is restored in
    ``postprocess``, and defensively at the start of every ``process`` call.
    """

    def title(self):
        return "Chunk Weight (Ultimate SDXL Fix)"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the accordion with the two weight textboxes and register infotext paste fields."""
        with InputAccordion(False, label="Chunk Weights") as enable:
            gr.Markdown("Версия с исправленной поддержкой SDXL. Разбивает по `BREAK`.")
            pos_weights = gr.Textbox(label="Positive Weights", placeholder="1.2, 0.8", lines=1)
            neg_weights = gr.Textbox(label="Negative Weights", placeholder="1.0, 0.5", lines=1)

        self.infotext_fields = [
            PasteField(pos_weights, "ChunkW+"),
            PasteField(neg_weights, "ChunkW-"),
        ]
        return [enable, pos_weights, neg_weights]

    @staticmethod
    def _parse_weights(text):
        """Parse a comma-separated list of floats such as ``"1.2, 0.8"``.

        Surrounding square brackets are tolerated so infotext written by older
        versions of this script (``"[1.2, 0.8]"``) still parses. Returns an
        empty list for empty input. If any entry is not a number, a warning is
        logged and an empty list is returned (weights are all-or-nothing so
        positions never shift).
        """
        if not text:
            return []
        body = str(text).strip().strip("[]")
        try:
            return [float(part) for part in (x.strip() for x in body.split(',')) if part]
        except ValueError:
            logger.warning("ChunkWeight: could not parse weights %r; ignoring them", text)
            return []

    @staticmethod
    def _format_weights(weights):
        """Format weights as ``"1.2, 0.8"`` so the infotext value can be pasted back and parsed."""
        return ", ".join(str(w) for w in weights)

    def process(self, p: StableDiffusionProcessing, enable: bool, pos_str: str, neg_str: str):
        """Install the chunk-weighting patch for this generation when enabled.

        Parses both weight lists into ``STATE`` and records them in the
        infotext. Patches the current model, remembering which model object was
        patched so it can be restored exactly. Gives ``p`` fresh conditioning
        caches so the weighted conditionings are computed rather than reused.
        """
        # Start from an unpatched model even if a previous run ended without postprocess.
        self.remove_patch()

        if not enable:
            return

        STATE['pos_weights'] = self._parse_weights(pos_str)
        STATE['neg_weights'] = self._parse_weights(neg_str)
        STATE['enabled'] = True

        if STATE['pos_weights']:
            p.extra_generation_params["ChunkW+"] = self._format_weights(STATE['pos_weights'])
        if STATE['neg_weights']:
            p.extra_generation_params["ChunkW-"] = self._format_weights(STATE['neg_weights'])

        model = shared.sd_model
        if model is not None and hasattr(model, 'get_learned_conditioning'):
            logger.info("ChunkWeight: Patching model...")
            STATE['original_method'] = model.get_learned_conditioning
            STATE['patched_model'] = model
            model.get_learned_conditioning = patched_get_learned_conditioning

        # Instance-level caches shadow any shared ones, so conditionings cached by
        # earlier unweighted runs are not reused and weighted results do not leak out.
        p.cached_c = [None, None]
        p.cached_uc = [None, None]
        p.cached_hr_c = [None, None]
        p.cached_hr_uc = [None, None]

    def postprocess(self, p, processed, *args):
        self.remove_patch()

    def remove_patch(self):
        """Restore the original ``get_learned_conditioning`` and disable weighting.

        The saved method is put back on the model object that was patched, so
        a model swap during generation does not install the old model's bound
        method on the new one. Falls back to ``shared.sd_model`` when the
        patched object is unknown. The saved state is always cleared.
        """
        original = STATE['original_method']
        if original is not None:
            model = STATE.get('patched_model')
            if model is None:
                model = shared.sd_model
            if model is not None:
                model.get_learned_conditioning = original
        STATE['original_method'] = None
        STATE['patched_model'] = None
        STATE['enabled'] = False
| |
|
def on_unload():
    """Restore the original ``get_learned_conditioning`` when this script is unloaded.

    Puts the saved method back on the model object that was patched (falling
    back to ``shared.sd_model``). Then clears the saved state so the same
    method is never restored twice or onto a later model.
    """
    original = STATE['original_method']
    if original is None:
        return
    model = STATE.get('patched_model')
    if model is None:
        model = shared.sd_model
    if model is not None:
        model.get_learned_conditioning = original
    STATE['original_method'] = None
    STATE['patched_model'] = None
    STATE['enabled'] = False


scripts.script_callbacks.on_script_unloaded(on_unload)