# Uploaded by dikdimon via SD-Hub: "Upload sd-webui-xl_vec using SD-Hub"
# (commit 8c0bf45, verified).
# Display name of the extension; used as the UI title and as the infotext key prefix.
NAME = "XL Vec"
import logging
import traceback
from threading import Lock
from torch import Tensor, FloatTensor, nn
import gradio as gr
from modules.processing import StableDiffusionProcessing
from modules import scripts
from scripts.sdhook import SDHook
from scripts.xl_clip import CLIP_SDXL, get_pooled
from scripts.xl_vec_xyz import init_xyz
# --- LOGGING ---
logger = logging.getLogger(__name__)
# --- CONSTANTS ---
SDXL_POOLED_DIM = 1280 # Размер pooled embedding вектора SDXL
AESTHETIC_SCORE_EPS = 0.01 # Допуск для сравнения float значений (aesthetic score)
DEFAULT_AESTHETIC_SCORE = 6.0
# --- PRESETS ---
PRESETS = {
"Manual / Custom": None,
"1:1 Square (1024x1024)": (1024, 1024),
"4:3 Photo (1152x896)": (1152, 896),
"3:4 Portrait (896x1152)": (896, 1152),
"16:9 Cinema (1344x768)": (1344, 768),
"9:16 Mobile (768x1344)": (768, 1344),
"21:9 Wide (1536x640)": (1536, 640),
"2:3 Classic (832x1216)": (832, 1216),
}
def first_scalar(t: Tensor) -> float:
    """
    Return the first element of ``t`` as a Python float.

    Unlike ``t.item()`` this also works when ``t`` holds one value per batch
    row (e.g. shape ``(B, 1)`` with ``B > 1``). Assumes every row carries the
    same value, which holds when a whole batch shares one aesthetic score —
    NOTE(review): confirm against the caller if per-row scores are ever used.

    Raises:
        IndexError: if ``t`` is empty.
    """
    return float(t.reshape(-1)[0].item())


def tensor_like(values: list[float], src: Tensor) -> Tensor:
    """
    Return a new tensor holding ``values`` with the dtype, device and shape of ``src``.

    If ``src`` has exactly ``len(values)`` elements, the values are reshaped into
    it (the single-prompt case). Otherwise they are broadcast over the leading
    dimensions, so ``[h, w]`` against a ``(B, 2)`` source yields ``B`` identical
    rows. This is needed when several prompts are encoded in one batch.

    Raises:
        RuntimeError: if ``values`` cannot be broadcast to ``src.shape``.
    """
    new = src.new_tensor(values)  # inherits src's dtype and device
    if new.numel() == src.numel():
        return new.reshape(src.shape)
    # expand_as returns a stride-0 view; clone so downstream code gets real memory.
    return new.expand_as(src).clone()


def hook_input(args: 'Hook', mod: nn.Module, inputs: tuple[dict[str, Tensor]]) -> tuple[dict[str, Tensor]]:
    """
    Forward pre-hook on the SDXL CLIP module.

    Overrides the conditioning inputs (original/target size, crop coordinates,
    aesthetic score) before encoding. The dict ``inputs[0]`` is modified in place.

    It also records on ``args`` whether this call encodes the positive or the
    negative prompt, because ``hook_output`` cannot tell once the aesthetic
    score has been overwritten.

    Args:
        args: Hook instance holding the user parameters.
        mod: CLIP module being called (must be CLIP_SDXL).
        inputs: Positional call arguments; ``inputs[0]`` is the conditioning dict.

    Returns:
        ``inputs`` itself; the dict inside has been updated.
    """
    if not args.enabled:
        return inputs
    assert isinstance(mod, CLIP_SDXL), f"Expected CLIP_SDXL, got {type(mod)}"
    input_data = inputs[0]
    # Forget any polarity left over from an earlier call.
    args._last_is_positive = None

    def put(name: str, values: list[float]) -> None:
        """Replace input_data[name], if present, with ``values`` shaped like the original tensor."""
        if name not in input_data:
            return
        try:
            input_data[name] = tensor_like(values, input_data[name])
        except RuntimeError as e:
            logger.error(f"[XL Vec] Cannot override {name}: {e}")

    # Geometry parameters; SDXL expects (height, width) / (top, left) order.
    put('original_size_as_tuple', [args.original_height, args.original_width])
    put('crop_coords_top_left', [args.crop_top, args.crop_left])
    put('target_size_as_tuple', [args.target_height, args.target_width])

    # Positive vs. negative is recognised by the aesthetic score the caller
    # supplied, so classify it BEFORE it is overwritten below.
    try:
        is_positive = args.is_positive_prompt(first_scalar(input_data['aesthetic_score']))
    except (KeyError, AttributeError, IndexError) as e:
        logger.warning(f"[XL Vec] Cannot access aesthetic_score: {e}")
        return inputs

    # hook_output runs after the forward pass and only sees the overwritten
    # score. That score equals the base threshold only by coincidence, so the
    # polarity is handed over explicitly.
    args._last_is_positive = is_positive
    put('aesthetic_score', [args.aesthetic_score if is_positive else args.negative_aesthetic_score])
    return inputs


def hook_output(args: 'Hook', mod: nn.Module, inputs: tuple[dict[str, Tensor]], output: dict) -> dict:
    """
    Forward hook on the SDXL CLIP module.

    Replaces the pooled part of ``output['vector']`` (its first
    SDXL_POOLED_DIM columns) with a re-encoded pooled embedding, scaled by the
    user's multiplier. Each batch row is re-encoded from its own prompt, or
    from the extra prompt when one is set. Errors are logged and the output is
    returned unchanged from that point on; generation is never aborted.

    Args:
        args: Hook instance holding the user parameters.
        mod: CLIP module that produced ``output`` (must be CLIP_SDXL).
        inputs: Positional call arguments; ``inputs[0]`` is the conditioning dict
            (already rewritten by hook_input).
        output: Conditioning dict containing the ``'vector'`` tensor, modified in place.

    Returns:
        ``output``.
    """
    if not args.enabled:
        return output
    try:
        input_data = inputs[0]

        # Polarity recorded by hook_input for this call. Consume it so it cannot leak into a later call.
        is_positive = getattr(args, '_last_is_positive', None)
        args._last_is_positive = None
        if is_positive is None:
            # hook_input did not classify this call; fall back to the current score.
            is_positive = args.is_positive_prompt(first_scalar(input_data['aesthetic_score']))

        if is_positive:
            prompt, index, multiplier = args.extra_prompt, args.token_index, args.eot_multiplier
        else:
            prompt, index, multiplier = (
                args.extra_negative_prompt, args.negative_token_index, args.negative_eot_multiplier
            )

        # Nothing configured for this prompt: keep the original vector.
        if not prompt and index == -1 and multiplier == 1.0:
            return output

        assert isinstance(mod, CLIP_SDXL), f"Expected CLIP_SDXL, got {type(mod)}"

        vector = output['vector']
        if vector.shape[1] < SDXL_POOLED_DIM:
            logger.error(
                f"[XL Vec] Vector dimension mismatch: expected >={SDXL_POOLED_DIM}, "
                f"got {vector.shape[1]}"
            )
            return output

        raw_txt = input_data.get('txt') or []
        texts = [raw_txt] if isinstance(raw_txt, str) else list(raw_txt)
        rows = vector.shape[0]
        if prompt:
            row_prompts = [prompt] * rows
        elif len(texts) == rows:
            # One source prompt per batch row.
            row_prompts = texts
        else:
            row_prompts = [texts[0]] * rows  # IndexError if no text at all -> logged below

        pooled_by_text: dict[str, tuple] = {}
        with args._lock:
            # Disable the hooks while re-encoding, in case get_pooled calls back into this module.
            args.enabled = False
            try:
                for text in dict.fromkeys(row_prompts):  # unique texts, order kept
                    pooled_by_text[text] = get_pooled(mod, text, index=index)
            finally:
                args.enabled = True

        for row, text in enumerate(row_prompts):
            pooled, _ = pooled_by_text[text]
            vector[row:row + 1, 0:SDXL_POOLED_DIM] = pooled * multiplier

        for text, (_, token_idx) in pooled_by_text.items():
            logger.info(
                f"[XL Vec] Vector override: '{texts}' -> '{text}' "
                f"@ token {token_idx} [x{multiplier:.2f}]"
            )
    except Exception as e:
        # Best effort: a failure here must not abort the generation.
        logger.exception(f"[XL Vec] Error in hook_output: {e}")
    return output
class Hook(SDHook):
    """
    SDHook that rewrites SDXL CLIP conditioning.

    hook_input (pre-forward) overrides the size/crop/aesthetic inputs, and
    hook_output (post-forward) overrides the pooled part of the output vector.
    """

    def __init__(
        self,
        enabled: bool,
        p: StableDiffusionProcessing,
        crop_left: float, crop_top: float,
        original_width: float, original_height: float,
        target_width: float, target_height: float,
        aesthetic_score: float, negative_aesthetic_score: float,
        extra_prompt: str | None, extra_negative_prompt: str | None,
        token_index: int | float, negative_token_index: int | float,
        eot_multiplier: float, negative_eot_multiplier: float,
        with_hr: bool,
        base_aesthetic_score: float,
    ):
        super().__init__(enabled)
        # Fail fast: raises ValueError before any parameter is stored.
        self._validate_params(
            aesthetic_score, negative_aesthetic_score, base_aesthetic_score,
            original_width, original_height, target_width, target_height,
        )
        self.p = p
        # Geometry and score parameters, normalised to float (assigned in this order).
        for attr_name, raw_value in (
            ('crop_left', crop_left),
            ('crop_top', crop_top),
            ('original_width', original_width),
            ('original_height', original_height),
            ('target_width', target_width),
            ('target_height', target_height),
            ('aesthetic_score', aesthetic_score),
            ('negative_aesthetic_score', negative_aesthetic_score),
        ):
            setattr(self, attr_name, float(raw_value))
        # Replacement prompt texts; None or '' lets hook_output fall back to the original prompt.
        self.extra_prompt = extra_prompt
        self.extra_negative_prompt = extra_negative_prompt
        # Token positions passed to get_pooled (-1 together with multiplier 1.0 means "no override").
        self.token_index = int(token_index)
        self.negative_token_index = int(negative_token_index)
        # Scale factors applied to the re-encoded pooled vector.
        self.eot_multiplier = float(eot_multiplier)
        self.negative_eot_multiplier = float(negative_eot_multiplier)
        self.with_hr = bool(with_hr)
        # Scores within AESTHETIC_SCORE_EPS of this value are treated as the positive prompt.
        self.base_aesthetic_score = float(base_aesthetic_score)
        # Guards the temporary `enabled = False` toggle performed by hook_output.
        self._lock = Lock()

    @staticmethod
    def _validate_params(
        aesthetic_score: float,
        negative_aesthetic_score: float,
        base_aesthetic_score: float,
        original_width: float,
        original_height: float,
        target_width: float,
        target_height: float
    ) -> None:
        """
        Check the user parameters.

        Raises:
            ValueError: if a score lies outside [0, 10] (NaN included) or a size
                is negative. Scores are checked first, each group in argument order.
        """
        named_scores = {
            "aesthetic_score": aesthetic_score,
            "negative_aesthetic_score": negative_aesthetic_score,
            "base_aesthetic_score": base_aesthetic_score,
        }
        for name, score in named_scores.items():
            if not (0 <= score <= 10):
                raise ValueError(f"{name} должен быть в диапазоне [0, 10], получено {score}")

        named_sizes = {
            "original_width": original_width,
            "original_height": original_height,
            "target_width": target_width,
            "target_height": target_height,
        }
        for name, size in named_sizes.items():
            if size < 0:
                raise ValueError(f"{name} не может быть отрицательным, получено {size}")

    def is_positive_prompt(self, aesthetic_score: float) -> bool:
        """
        Tell whether a conditioning call belongs to the positive prompt.

        Args:
            aesthetic_score: Aesthetic score of the current call.

        Returns:
            True when the score is within AESTHETIC_SCORE_EPS of
            base_aesthetic_score, False otherwise (negative prompt).
        """
        distance = abs(aesthetic_score - self.base_aesthetic_score)
        return distance < AESTHETIC_SCORE_EPS

    def get_prompt_params(self, aesthetic_score: float) -> tuple[str | None, int, float]:
        """
        Pick the override parameters matching the prompt's polarity.

        Args:
            aesthetic_score: Aesthetic score of the current call.

        Returns:
            (extra prompt, token index, multiplier) for the positive or the
            negative prompt.
        """
        if not self.is_positive_prompt(aesthetic_score):
            return self.extra_negative_prompt, self.negative_token_index, self.negative_eot_multiplier
        return self.extra_prompt, self.token_index, self.eot_multiplier

    def cleanup(self) -> None:
        """Detach all hooks via __exit__; failures are logged, never raised."""
        try:
            self.__exit__(None, None, None)
        except Exception as err:
            logger.warning(f"[XL Vec] Error during cleanup: {err}")

    def hook_clip(self, p: StableDiffusionProcessing, clip: nn.Module) -> None:
        """Attach hook_input (pre-forward) and hook_output (post-forward) to ``clip`` on SDXL models only."""
        model = p.sd_model
        if not (hasattr(model, 'is_sdxl') and model.is_sdxl):
            logger.debug("[XL Vec] Model is not SDXL, skipping hooks")
            return
        self.hook_layer_pre(clip, lambda *a, **kw: hook_input(self, *a, **kw))
        self.hook_layer(clip, lambda *a, **kw: hook_output(self, *a, **kw))
class Script(scripts.Script):
    """Always-visible WebUI script exposing the SDXL conditioning controls of XL Vec."""

    def title(self) -> str:
        return NAME

    def show(self, is_img2img) -> scripts.AlwaysVisible:
        # Rendered as an accordion on both tabs, not in the script dropdown.
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """
        Build the XL Vec accordion.

        Returns:
            The components whose values WebUI passes to ``process``. The list
            order must match that method's parameters.
        """
        with gr.Accordion(NAME, open=False):
            with gr.Row():
                enabled = gr.Checkbox(label='Enable XL Vec', value=False)
                # Hidden control; its value is only forwarded to Hook.
                with_hr = gr.Checkbox(label='Active on Hires Fix', value=False, visible=False)

            # --- GEOMETRY SECTION ---
            with gr.Group():
                gr.Markdown("### 📐 SDXL Geometry & Size")
                preset_dropdown = gr.Dropdown(
                    label="⚡ Quick Resolution Preset",
                    choices=list(PRESETS.keys()),
                    value="Manual / Custom",
                    type="value"
                )
                with gr.Row():
                    original_width = gr.Slider(
                        minimum=-1, maximum=4096, step=1, value=-1,
                        label='Original Width (-1=auto)'
                    )
                    original_height = gr.Slider(
                        minimum=-1, maximum=4096, step=1, value=-1,
                        label='Original Height (-1=auto)'
                    )
                with gr.Row():
                    target_width = gr.Slider(
                        minimum=-1, maximum=4096, step=1, value=-1,
                        label='Target Width (-1=auto)'
                    )
                    target_height = gr.Slider(
                        minimum=-1, maximum=4096, step=1, value=-1,
                        label='Target Height (-1=auto)'
                    )
                size_sliders = [original_width, original_height, target_width, target_height]

                # Callback: Dropdown -> Sliders
                def apply_preset(choice):
                    """Copy the chosen preset size into both the original and target sliders."""
                    size = PRESETS.get(choice)
                    if size is None:
                        return gr.update(), gr.update(), gr.update(), gr.update()
                    w, h = size
                    return w, h, w, h

                preset_dropdown.change(
                    fn=apply_preset,
                    inputs=[preset_dropdown],
                    outputs=size_sliders
                )

                # Callback: Sliders -> Dropdown (back to Manual once edited by hand)
                def sync_preset(choice, ow, oh, tw, th):
                    """
                    Switch the dropdown to "Manual / Custom" once the sliders stop matching the preset.

                    apply_preset changes the sliders programmatically, which also fires
                    their .change events. An unconditional reset would therefore make
                    the dropdown jump back to Manual right after a preset is picked.
                    """
                    size = PRESETS.get(choice)
                    if size is not None and (ow, oh) == size and (tw, th) == size:
                        return gr.update()
                    return "Manual / Custom"

                for slider in size_sliders:
                    slider.change(
                        fn=sync_preset,
                        inputs=[preset_dropdown, *size_sliders],
                        outputs=[preset_dropdown]
                    )

                with gr.Accordion("✂️ Crop Settings", open=False):
                    with gr.Row():
                        crop_left = gr.Slider(
                            minimum=-10000, maximum=10000, step=1, value=0,
                            label='Crop Left'
                        )
                        crop_top = gr.Slider(
                            minimum=-10000, maximum=10000, step=1, value=0,
                            label='Crop Top'
                        )

            # --- AESTHETICS SECTION ---
            with gr.Group():
                gr.Markdown("### 🎨 Aesthetics")
                with gr.Row():
                    aesthetic_score = gr.Slider(
                        minimum=0.0, maximum=10.0, step=0.1, value=6.0,
                        label="Positive Aesthetic Score"
                    )
                    negative_aesthetic_score = gr.Slider(
                        minimum=0.0, maximum=10.0, step=0.1, value=2.5,
                        label="Negative Aesthetic Score"
                    )
                with gr.Accordion("⚙️ Detection Threshold (Advanced)", open=False):
                    base_aesthetic_score = gr.Slider(
                        minimum=0.0, maximum=10.0, step=0.1, value=6.0,
                        label="Base Score Threshold"
                    )
                    # gr.Info shows a toast only from inside an event handler. Called
                    # while the UI is being built, it never reaches the page, so the
                    # hint is rendered as static text instead.
                    gr.Markdown("Change this ONLY if you modified 'SDXL Aesthetic Score' in WebUI settings.")

            # --- VECTORS SECTION ---
            with gr.Accordion("🧠 Token & Vector Control", open=False):
                with gr.Row():
                    eot_multiplier = gr.Slider(
                        minimum=-4.0, maximum=8.0, step=0.05, value=1.0,
                        label='Pos. Vector Mult'
                    )
                    negative_eot_multiplier = gr.Slider(
                        minimum=-4.0, maximum=8.0, step=0.05, value=1.0,
                        label='Neg. Vector Mult'
                    )
                with gr.Row():
                    token_index = gr.Slider(
                        minimum=-77, maximum=76, step=1, value=-1,
                        label='Pos. Token Index'
                    )
                    negative_token_index = gr.Slider(
                        minimum=-77, maximum=76, step=1, value=-1,
                        label='Neg. Token Index'
                    )
                with gr.Row():
                    extra_prompt = gr.Textbox(
                        lines=1, label='Extra Prompt',
                        placeholder="Override positive prompt text..."
                    )
                    extra_negative_prompt = gr.Textbox(
                        lines=1, label='Extra Negative',
                        placeholder="Override negative prompt text..."
                    )

        return [
            enabled, crop_left, crop_top, original_width, original_height,
            target_width, target_height, aesthetic_score, negative_aesthetic_score,
            extra_prompt, extra_negative_prompt, token_index, negative_token_index,
            eot_multiplier, negative_eot_multiplier, with_hr,
            base_aesthetic_score
        ]

    def _release_hooker(self) -> None:
        """Detach the hooks installed by the previous ``process`` call, if any."""
        hooker = getattr(self, 'last_hooker', None)
        if hooker is not None:
            hooker.cleanup()
        self.last_hooker = None

    @staticmethod
    def _reset_cond_caches(p) -> None:
        """
        Drop cached prompt conditionings so they are recomputed through the hooks.

        The hires-fix caches (assumed to be ``cached_hr_c`` / ``cached_hr_uc`` on
        txt2img processing — the hasattr guard makes this a no-op otherwise) are
        cleared too. The hooks stay attached for the whole generation, and a
        hires conditioning cached by an earlier run with other settings would
        presumably be reused as-is.
        """
        for attr in ('cached_c', 'cached_uc', 'cached_hr_c', 'cached_hr_uc'):
            if hasattr(p, attr):
                setattr(p, attr, [None, None])

    def process(
        self, p, enabled, crop_left, crop_top, original_width, original_height,
        target_width, target_height, aesthetic_score, negative_aesthetic_score,
        extra_prompt, extra_negative_prompt, token_index, negative_token_index,
        eot_multiplier, negative_eot_multiplier, with_hr,
        base_aesthetic_score=DEFAULT_AESTHETIC_SCORE
    ):
        """
        Install the CLIP hooks for this generation and record the settings in the infotext.

        Parameters arrive in the order returned by ``ui``. Invalid parameters
        (Hook raises ValueError) are logged and the script does nothing.
        """
        # Hooks may still be attached if the previous run failed before postprocess.
        self._release_hooker()
        if not enabled:
            return

        # -1 on a size slider means "use the generation size".
        original_width = p.width if original_width < 0 else original_width
        original_height = p.height if original_height < 0 else original_height
        target_width = p.width if target_width < 0 else target_width
        target_height = p.height if target_height < 0 else target_height

        try:
            self.last_hooker = Hook(
                enabled=True, p=p,
                crop_left=crop_left, crop_top=crop_top,
                original_width=original_width, original_height=original_height,
                target_width=target_width, target_height=target_height,
                aesthetic_score=aesthetic_score,
                negative_aesthetic_score=negative_aesthetic_score,
                extra_prompt=extra_prompt, extra_negative_prompt=extra_negative_prompt,
                token_index=token_index, negative_token_index=negative_token_index,
                eot_multiplier=eot_multiplier,
                negative_eot_multiplier=negative_eot_multiplier,
                with_hr=with_hr, base_aesthetic_score=base_aesthetic_score
            )
        except ValueError as e:
            logger.error(f"[XL Vec] Invalid parameters: {e}")
            return

        self.last_hooker.setup(p)
        self.last_hooker.__enter__()

        # Infotext: record every setting that changes the result, so images can be reproduced.
        params = p.extra_generation_params
        params.update({
            f'[{NAME}] Enabled': enabled,
            f'[{NAME}] Original Size': f"{int(original_width)}x{int(original_height)}",
            f'[{NAME}] Target Size': f"{int(target_width)}x{int(target_height)}",
            f'[{NAME}] Aesthetic Score': aesthetic_score,
            f'[{NAME}] Negative Aesthetic Score': negative_aesthetic_score,
        })
        if crop_left != 0 or crop_top != 0:
            params[f'[{NAME}] Crop'] = f"{crop_left},{crop_top}"
        if abs(base_aesthetic_score - DEFAULT_AESTHETIC_SCORE) > AESTHETIC_SCORE_EPS:
            params[f'[{NAME}] Base Score'] = base_aesthetic_score
        if eot_multiplier != 1.0:
            params[f'[{NAME}] Token Mult'] = eot_multiplier
        if negative_eot_multiplier != 1.0:
            params[f'[{NAME}] Neg Token Mult'] = negative_eot_multiplier
        if int(token_index) != -1:
            params[f'[{NAME}] Token Index'] = int(token_index)
        if int(negative_token_index) != -1:
            params[f'[{NAME}] Neg Token Index'] = int(negative_token_index)
        if extra_prompt:
            params[f'[{NAME}] Extra Prompt'] = extra_prompt
        if extra_negative_prompt:
            params[f'[{NAME}] Extra Negative'] = extra_negative_prompt

        # Force the conditioning to be recomputed through the freshly installed hooks.
        self._reset_cond_caches(p)

    def postprocess(self, p, processed, *args):
        """Detach the CLIP hooks once the generation has finished.

        This stops them from affecting CLIP calls made outside this script's generations.
        """
        self._release_hooker()
# Hand this script class and its display name to the XYZ-plot integration in
# scripts.xl_vec_xyz (presumably registers XL Vec parameters as XYZ grid axes —
# confirm in that module).
init_xyz(Script, NAME)