|
|
# Display name of the extension. It is returned as the script title, used as
# the label of the top-level accordion, used as the prefix of the extra
# generation parameter keys, and passed to init_xyz() at the end of the file.
NAME = 'XL Vec'
|
|
|
|
|
import logging |
|
|
import traceback |
|
|
from threading import Lock |
|
|
from torch import Tensor, FloatTensor, nn |
|
|
import gradio as gr |
|
|
from modules.processing import StableDiffusionProcessing |
|
|
from modules import scripts |
|
|
|
|
|
from scripts.sdhook import SDHook |
|
|
from scripts.xl_clip import CLIP_SDXL, get_pooled |
|
|
from scripts.xl_vec_xyz import init_xyz |
|
|
|
|
|
|
|
|
# Module-level logger used by the hooks and the script below.
logger = logging.getLogger(__name__)

# Number of leading columns of output['vector'] that hook_output overwrites
# with the pooled vector.
SDXL_POOLED_DIM = 1280

# Tolerance used when comparing an aesthetic score against the base score.
AESTHETIC_SCORE_EPS = 0.01

# Default value of the base_aesthetic_score parameter of Script.process();
# a base score that differs from it is recorded in the generation parameters.
DEFAULT_AESTHETIC_SCORE = 6.0

# Resolution presets offered by the UI dropdown: label -> (width, height).
# The None entry is the manual mode, which leaves the size sliders untouched.
PRESETS = {
    "Manual / Custom": None,
    "1:1 Square (1024x1024)": (1024, 1024),
    "4:3 Photo (1152x896)": (1152, 896),
    "3:4 Portrait (896x1152)": (896, 1152),
    "16:9 Cinema (1344x768)": (1344, 768),
    "9:16 Mobile (768x1344)": (768, 1344),
    "21:9 Wide (1536x640)": (1536, 640),
    "2:3 Classic (832x1216)": (832, 1216),
}
|
|
|
|
|
|
|
|
def _first_scalar(value: Tensor) -> float:
    """Return the first element of *value* as a Python number.

    ``Tensor.item()`` only accepts one-element tensors, so it raises as soon
    as the conditioning tensor has more than one row. Reading element 0 works
    for both the single-row and the multi-row layout.
    """
    # assumes every row of a batched aesthetic_score carries the same value
    # -- TODO confirm against the caller that builds the conditioning dict
    return value.reshape(-1)[0].item()


def hook_input(args: 'Hook', mod: nn.Module, inputs: tuple[dict[str, Tensor]]) -> tuple[dict[str, Tensor]]:
    """
    Pre-forward hook: overwrite the conditioning scalars in the CLIP input.

    Replaces the original size, crop coordinates, target size and aesthetic
    score found in ``inputs[0]`` with the values stored on ``args``. The dict
    is modified in place. The aesthetic score that was present *before* the
    override is recorded on ``args`` so that ``hook_output`` can still tell a
    positive prompt from a negative one after the value has been replaced.

    Args:
        args: Hook instance holding the override values.
        mod: The hooked module; must be a CLIP_SDXL instance.
        inputs: Positional inputs of the module; ``inputs[0]`` is the
            conditioning dict.

    Returns:
        The same ``inputs`` tuple (its first element modified in place).
    """
    if not args.enabled:
        return inputs

    assert isinstance(mod, CLIP_SDXL), f"Expected CLIP_SDXL, got {type(mod)}"
    input_data = inputs[0]

    # Drop whatever an earlier call left behind, so hook_output never consumes
    # a stale score if this call cannot read the current one.
    args._pre_override_score = None

    def put(name: str, v: list[float]) -> None:
        """Replace input_data[name] with *v*, keeping device, dtype and shape."""
        if name not in input_data:
            return
        src = input_data[name]
        fresh = FloatTensor(v).to(dtype=src.dtype, device=src.device)
        if fresh.numel() == src.numel():
            input_data[name] = fresh.reshape(src.shape)
        else:
            # presumably a batched tensor of shape (rows, len(v)): repeat the
            # same values on every row instead of failing in reshape()
            input_data[name] = fresh.expand(src.shape).clone()

    put('original_size_as_tuple', [args.original_height, args.original_width])
    put('crop_coords_top_left', [args.crop_top, args.crop_left])
    put('target_size_as_tuple', [args.target_height, args.target_width])

    try:
        current_score = _first_scalar(input_data['aesthetic_score'])
        # Remember the untouched score: after put() below the dict only holds
        # the override, which can no longer be compared with the base score.
        args._pre_override_score = current_score
        if args.is_positive_prompt(current_score):
            put('aesthetic_score', [args.aesthetic_score])
        else:
            put('aesthetic_score', [args.negative_aesthetic_score])
    except (KeyError, AttributeError) as e:
        logger.warning(f"[XL Vec] Cannot access aesthetic_score: {e}")

    return inputs


def hook_output(args: 'Hook', mod: nn.Module, inputs: tuple[dict[str, Tensor]], output: dict) -> dict:
    """
    Forward hook: replace the pooled part of the CLIP output vector.

    Chooses the positive or negative settings (extra prompt, token index,
    multiplier) via ``args.get_prompt_params`` and, unless they are all at
    their defaults, overwrites the first SDXL_POOLED_DIM columns of
    ``output['vector']`` in place with the pooled vector returned by
    ``get_pooled``, scaled by the multiplier. Any exception is logged and the
    output is returned as it is at that point.

    Args:
        args: Hook instance holding the settings.
        mod: The hooked module; must be a CLIP_SDXL instance.
        inputs: Positional inputs of the module; ``inputs[0]`` is the
            conditioning dict.
        output: Module output, expected to contain the key 'vector'.

    Returns:
        The same ``output`` dict.
    """
    if not args.enabled:
        return output

    try:
        # hook_input has already overwritten inputs[0]['aesthetic_score'], so
        # use the score it recorded before the override. Only fall back to the
        # tensor when nothing was recorded.
        current_score = getattr(args, '_pre_override_score', None)
        args._pre_override_score = None
        if current_score is None:
            current_score = _first_scalar(inputs[0]['aesthetic_score'])
        prompt, index, multiplier = args.get_prompt_params(current_score)

        # Nothing to override: no extra prompt and default index/multiplier.
        if not prompt and index == -1 and multiplier == 1.0:
            return output

        # No extra prompt given: re-encode the prompt that is being processed.
        # NOTE(review): only the first text is used; if inputs[0]['txt'] can
        # hold several different texts, every row gets the first one's vector.
        if not prompt:
            prompt = inputs[0]['txt'][0]

        assert isinstance(mod, CLIP_SDXL), f"Expected CLIP_SDXL, got {type(mod)}"

        # The hooks are switched off while get_pooled() runs and switched
        # back on afterwards, even if it raises.
        with args._lock:
            args.enabled = False
            try:
                pooled, token_idx = get_pooled(mod, prompt, index=index)
            finally:
                args.enabled = True

        if output['vector'].shape[1] >= SDXL_POOLED_DIM:
            output['vector'][:, 0:SDXL_POOLED_DIM] = pooled[:] * multiplier
            logger.info(
                f"[XL Vec] Vector override: '{inputs[0]['txt']}' -> '{prompt}' "
                f"@ token {token_idx} [x{multiplier:.2f}]"
            )
        else:
            logger.error(
                f"[XL Vec] Vector dimension mismatch: expected >={SDXL_POOLED_DIM}, "
                f"got {output['vector'].shape[1]}"
            )

    except Exception as e:
        # Best effort: a failed override must not abort the generation.
        logger.error(f"[XL Vec] Error in hook_output: {e}")
        traceback.print_exc()

    return output
|
|
|
|
|
|
|
|
class Hook(SDHook):
    """Hook that carries the user's SDXL conditioning overrides and attaches them to CLIP."""

    def __init__(
        self,
        enabled: bool,
        p: StableDiffusionProcessing,
        crop_left: float, crop_top: float,
        original_width: float, original_height: float,
        target_width: float, target_height: float,
        aesthetic_score: float, negative_aesthetic_score: float,
        extra_prompt: str | None, extra_negative_prompt: str | None,
        token_index: int | float, negative_token_index: int | float,
        eot_multiplier: float, negative_eot_multiplier: float,
        with_hr: bool,
        base_aesthetic_score: float,
    ):
        """Validate the scores and sizes, then store every setting in its canonical type."""
        super().__init__(enabled)

        # Reject out-of-range values before anything is stored on the instance.
        self._validate_params(
            aesthetic_score,
            negative_aesthetic_score,
            base_aesthetic_score,
            original_width,
            original_height,
            target_width,
            target_height,
        )

        self.p = p

        # Geometry
        self.crop_left = float(crop_left)
        self.crop_top = float(crop_top)
        self.original_width = float(original_width)
        self.original_height = float(original_height)
        self.target_width = float(target_width)
        self.target_height = float(target_height)

        # Aesthetic scores written into the conditioning
        self.aesthetic_score = float(aesthetic_score)
        self.negative_aesthetic_score = float(negative_aesthetic_score)

        # Pooled-vector override settings (positive / negative)
        self.extra_prompt = extra_prompt
        self.extra_negative_prompt = extra_negative_prompt
        self.token_index = int(token_index)
        self.negative_token_index = int(negative_token_index)
        self.eot_multiplier = float(eot_multiplier)
        self.negative_eot_multiplier = float(negative_eot_multiplier)

        self.with_hr = bool(with_hr)
        self.base_aesthetic_score = float(base_aesthetic_score)

        # Held by hook_output while it temporarily switches the hook off.
        self._lock = Lock()

    @staticmethod
    def _validate_params(
        aesthetic_score: float,
        negative_aesthetic_score: float,
        base_aesthetic_score: float,
        original_width: float,
        original_height: float,
        target_width: float,
        target_height: float
    ) -> None:
        """Check the scores and sizes.

        Raises:
            ValueError: if a score lies outside [0, 10] or a size is negative.
        """
        score_checks = (
            ("aesthetic_score", aesthetic_score),
            ("negative_aesthetic_score", negative_aesthetic_score),
            ("base_aesthetic_score", base_aesthetic_score),
        )
        for name, score in score_checks:
            in_range = 0 <= score <= 10
            if not in_range:
                raise ValueError(f"{name} должен быть в диапазоне [0, 10], получено {score}")

        size_checks = (
            ("original_width", original_width),
            ("original_height", original_height),
            ("target_width", target_width),
            ("target_height", target_height),
        )
        for name, size in size_checks:
            if size < 0:
                raise ValueError(f"{name} не может быть отрицательным, получено {size}")

    def is_positive_prompt(self, aesthetic_score: float) -> bool:
        """
        Tell whether the given score belongs to the positive prompt.

        Args:
            aesthetic_score: Aesthetic score currently found in the conditioning.

        Returns:
            True when the score is within AESTHETIC_SCORE_EPS of the base
            score (positive prompt), False otherwise (negative prompt).
        """
        distance = abs(aesthetic_score - self.base_aesthetic_score)
        return distance < AESTHETIC_SCORE_EPS

    def get_prompt_params(self, aesthetic_score: float) -> tuple[str | None, int, float]:
        """
        Pick the positive or the negative override settings for a score.

        Args:
            aesthetic_score: Aesthetic score currently found in the conditioning.

        Returns:
            Tuple (prompt, token_index, multiplier).
        """
        if not self.is_positive_prompt(aesthetic_score):
            return (
                self.extra_negative_prompt,
                self.negative_token_index,
                self.negative_eot_multiplier,
            )
        return self.extra_prompt, self.token_index, self.eot_multiplier

    def cleanup(self) -> None:
        """Leave the hook context; a failure is logged instead of raised."""
        try:
            self.__exit__(None, None, None)
        except Exception as exc:
            logger.warning(f"[XL Vec] Error during cleanup: {exc}")

    def hook_clip(self, p: StableDiffusionProcessing, clip: nn.Module) -> None:
        """Attach the input and output hooks to the CLIP module of an SDXL model."""
        if not getattr(p.sd_model, 'is_sdxl', False):
            logger.debug("[XL Vec] Model is not SDXL, skipping hooks")
            return

        # Bind this instance as the first argument of the module-level hooks.
        def before_forward(*hook_args, **hook_kwargs):
            return hook_input(self, *hook_args, **hook_kwargs)

        def after_forward(*hook_args, **hook_kwargs):
            return hook_output(self, *hook_args, **hook_kwargs)

        self.hook_layer_pre(clip, before_forward)
        self.hook_layer(clip, after_forward)
|
|
|
|
|
|
|
|
class Script(scripts.Script):
    """Always-visible WebUI script that exposes the SDXL conditioning controls."""

    def title(self) -> str:
        """Return the display name of the script."""
        return NAME

    def show(self, is_img2img) -> scripts.AlwaysVisible:
        """Make the script always visible, regardless of the tab."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """
        Build the Gradio controls.

        Returns:
            The input components, in the exact order in which process()
            receives their values.
        """
        with gr.Accordion(NAME, open=False):
            with gr.Row():
                enabled = gr.Checkbox(label='Enable XL Vec', value=False)
                with_hr = gr.Checkbox(label='Active on Hires Fix', value=False, visible=False)

            with gr.Group():
                gr.Markdown("### 📐 SDXL Geometry & Size")

                preset_dropdown = gr.Dropdown(
                    label="⚡ Quick Resolution Preset",
                    choices=list(PRESETS.keys()),
                    value="Manual / Custom",
                    type="value"
                )

                with gr.Row():
                    original_width = gr.Slider(
                        minimum=-1, maximum=4096, step=1, value=-1,
                        label='Original Width (-1=auto)'
                    )
                    original_height = gr.Slider(
                        minimum=-1, maximum=4096, step=1, value=-1,
                        label='Original Height (-1=auto)'
                    )

                with gr.Row():
                    target_width = gr.Slider(
                        minimum=-1, maximum=4096, step=1, value=-1,
                        label='Target Width (-1=auto)'
                    )
                    target_height = gr.Slider(
                        minimum=-1, maximum=4096, step=1, value=-1,
                        label='Target Height (-1=auto)'
                    )

                def apply_preset(choice):
                    """Copy the preset's (width, height) into all four size sliders."""
                    if choice in PRESETS and PRESETS[choice] is not None:
                        w, h = PRESETS[choice]
                        return w, h, w, h
                    # Manual mode (or unknown label): leave the sliders alone.
                    return gr.update(), gr.update(), gr.update(), gr.update()

                preset_dropdown.change(
                    fn=apply_preset,
                    inputs=[preset_dropdown],
                    outputs=[original_width, original_height, target_width, target_height]
                )

                def reset_dropdown():
                    """Switch the dropdown back to manual mode."""
                    return "Manual / Custom"

                for slider in [original_width, original_height, target_width, target_height]:
                    # `.change` also fires when apply_preset updates the slider,
                    # which would flip the dropdown back to manual right after a
                    # preset is chosen. `.input` fires for user edits only; fall
                    # back to `.change` if this Gradio build does not offer it.
                    on_user_edit = getattr(slider, 'input', slider.change)
                    on_user_edit(fn=reset_dropdown, inputs=None, outputs=[preset_dropdown])

                with gr.Accordion("✂️ Crop Settings", open=False):
                    with gr.Row():
                        crop_left = gr.Slider(
                            minimum=-10000, maximum=10000, step=1, value=0,
                            label='Crop Left'
                        )
                        crop_top = gr.Slider(
                            minimum=-10000, maximum=10000, step=1, value=0,
                            label='Crop Top'
                        )

            with gr.Group():
                gr.Markdown("### 🎨 Aesthetics")
                with gr.Row():
                    aesthetic_score = gr.Slider(
                        minimum=0.0, maximum=10.0, step=0.1, value=6.0,
                        label="Positive Aesthetic Score"
                    )
                    negative_aesthetic_score = gr.Slider(
                        minimum=0.0, maximum=10.0, step=0.1, value=2.5,
                        label="Negative Aesthetic Score"
                    )

                with gr.Accordion("⚙️ Detection Threshold (Advanced)", open=False):
                    base_aesthetic_score = gr.Slider(
                        minimum=0.0, maximum=10.0, step=0.1, value=6.0,
                        label="Base Score Threshold"
                    )
                    # gr.Info() is an alert meant to be raised from an event
                    # handler; called while the layout is built it puts nothing
                    # on the page, so the hint is rendered as static Markdown.
                    gr.Markdown("Change this ONLY if you modified 'SDXL Aesthetic Score' in WebUI settings.")

            with gr.Accordion("🧠 Token & Vector Control", open=False):
                with gr.Row():
                    eot_multiplier = gr.Slider(
                        minimum=-4.0, maximum=8.0, step=0.05, value=1.0,
                        label='Pos. Vector Mult'
                    )
                    negative_eot_multiplier = gr.Slider(
                        minimum=-4.0, maximum=8.0, step=0.05, value=1.0,
                        label='Neg. Vector Mult'
                    )
                with gr.Row():
                    token_index = gr.Slider(
                        minimum=-77, maximum=76, step=1, value=-1,
                        label='Pos. Token Index'
                    )
                    negative_token_index = gr.Slider(
                        minimum=-77, maximum=76, step=1, value=-1,
                        label='Neg. Token Index'
                    )
                with gr.Row():
                    extra_prompt = gr.Textbox(
                        lines=1, label='Extra Prompt',
                        placeholder="Override positive prompt text..."
                    )
                    extra_negative_prompt = gr.Textbox(
                        lines=1, label='Extra Negative',
                        placeholder="Override negative prompt text..."
                    )

        return [
            enabled, crop_left, crop_top, original_width, original_height,
            target_width, target_height, aesthetic_score, negative_aesthetic_score,
            extra_prompt, extra_negative_prompt, token_index, negative_token_index,
            eot_multiplier, negative_eot_multiplier, with_hr,
            base_aesthetic_score
        ]

    def process(
        self, p, enabled, crop_left, crop_top, original_width, original_height,
        target_width, target_height, aesthetic_score, negative_aesthetic_score,
        extra_prompt, extra_negative_prompt, token_index, negative_token_index,
        eot_multiplier, negative_eot_multiplier, with_hr,
        base_aesthetic_score=DEFAULT_AESTHETIC_SCORE
    ):
        """
        Install the conditioning hook for this generation.

        Removes the hook left by a previous call, and when the script is
        enabled builds a new Hook from the UI values, activates it, records
        the settings in ``p.extra_generation_params`` and clears the cached
        conditioning on ``p``. Invalid parameters are logged and leave the
        generation unhooked.
        """
        # Remove the hook of the previous run of this Script instance.
        # NOTE(review): within this file the hook is only removed here, i.e.
        # it stays installed until process() is called again on this instance.
        if getattr(self, 'last_hooker', None) is not None:
            self.last_hooker.cleanup()
            self.last_hooker = None

        if not enabled:
            return

        # Negative sizes mean "auto": use the size of the generation itself.
        if original_width < 0:
            original_width = p.width
        if original_height < 0:
            original_height = p.height
        if target_width < 0:
            target_width = p.width
        if target_height < 0:
            target_height = p.height

        try:
            self.last_hooker = Hook(
                enabled=True, p=p,
                crop_left=crop_left, crop_top=crop_top,
                original_width=original_width, original_height=original_height,
                target_width=target_width, target_height=target_height,
                aesthetic_score=aesthetic_score,
                negative_aesthetic_score=negative_aesthetic_score,
                extra_prompt=extra_prompt, extra_negative_prompt=extra_negative_prompt,
                token_index=token_index, negative_token_index=negative_token_index,
                eot_multiplier=eot_multiplier,
                negative_eot_multiplier=negative_eot_multiplier,
                with_hr=with_hr, base_aesthetic_score=base_aesthetic_score
            )
        except ValueError as e:
            logger.error(f"[XL Vec] Invalid parameters: {e}")
            return

        self.last_hooker.setup(p)
        self.last_hooker.__enter__()

        # Record the settings alongside the other generation parameters.
        p.extra_generation_params.update({
            f'[{NAME}] Enabled': enabled,
            f'[{NAME}] Original Size': f"{int(original_width)}x{int(original_height)}",
            f'[{NAME}] Target Size': f"{int(target_width)}x{int(target_height)}",
            f'[{NAME}] Aesthetic Score': aesthetic_score,
        })

        # Optional entries are only written when they differ from the default.
        if crop_left != 0 or crop_top != 0:
            p.extra_generation_params[f'[{NAME}] Crop'] = f"{crop_left},{crop_top}"

        if abs(base_aesthetic_score - DEFAULT_AESTHETIC_SCORE) > AESTHETIC_SCORE_EPS:
            p.extra_generation_params[f'[{NAME}] Base Score'] = base_aesthetic_score

        if eot_multiplier != 1.0:
            p.extra_generation_params[f'[{NAME}] Token Mult'] = eot_multiplier

        # Clear the cached conditioning on p; presumably this forces the
        # prompts to be re-encoded with the hook in place -- TODO confirm.
        if hasattr(p, 'cached_c'):
            p.cached_c = [None, None]
        if hasattr(p, 'cached_uc'):
            p.cached_uc = [None, None]
|
|
|
|
|
|
|
|
# Hand the script class and its display name to the helper imported from
# scripts.xl_vec_xyz; presumably this registers the XL Vec options with the
# X/Y/Z plot script -- confirm in scripts/xl_vec_xyz.py.
init_xyz(Script, NAME)